Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-03 03:31:05 +06:00
[Fix] ViViT interpolate_pos_encoding (#33815)
* fix: test_inference_interpolate_pos_encoding
* style: make style; make fixup
* test: add suggestion to test_modeling_vivit
* chore: add suggestions
* style: make style
* [run_slow] vivit
* ci: slow test fix
* [run_slow] vivit
This commit is contained in:
parent 8635802af9
commit 68a2b50069
@@ -104,9 +104,10 @@ class VivitEmbeddings(nn.Module):
             torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
         )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.tubelet_size[1:]
         self.config = config
 
-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+    # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
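The root cause is visible in the new `self.patch_size` attribute: ViViT embeds videos with 3D tubelets rather than 2D patches, so `config.tubelet_size` holds `[temporal_depth, height, width]` and only its spatial slice matters when interpolating position encodings. A minimal sketch of that relationship, assuming the `[2, 16, 16]` tubelet of the 16x2 checkpoints:

```python
# Illustrative values: ViViT-B/16x2 uses 2-frame tubelets of 16x16 pixels.
tubelet_size = [2, 16, 16]  # [temporal_depth, height, width]

patch_size = tubelet_size[1:]  # keep only the spatial dimensions
assert patch_size == [16, 16]

# A 480x480 frame then maps to a 30x30 grid of spatial positions.
assert (480 // patch_size[0], 480 // patch_size[1]) == (30, 30)
```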
@@ -129,8 +130,8 @@ class VivitEmbeddings(nn.Module):
 
         dim = embeddings.shape[-1]
 
-        new_height = height // self.patch_size
-        new_width = width // self.patch_size
+        new_height = height // self.patch_size[0]
+        new_width = width // self.patch_size[1]
 
         sqrt_num_positions = torch_int(num_positions**0.5)
         patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
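For context, below is a self-contained sketch of the interpolation these two lines feed into, modeled on the ViT routine that the `# Adapted from` comment points at. The standalone signature is illustrative (the real method lives on `VivitEmbeddings` and uses the `torch_int` helper); the key point is that `height` and `width` are now divided by the matching spatial components of the tubelet rather than by the whole list:

```python
import torch
import torch.nn.functional as F


def interpolate_pos_encoding(pos_embed: torch.Tensor, height: int, width: int, patch_size) -> torch.Tensor:
    """Resize a (1, num_patches + 1, dim) position table to a new spatial grid."""
    class_pos_embed, patch_pos_embed = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    num_positions = patch_pos_embed.shape[1]

    # The fix: patch_size is a [height, width] pair, not an int, so index it.
    new_height = height // patch_size[0]
    new_width = width // patch_size[1]

    sqrt_num_positions = int(num_positions**0.5)  # the pre-training grid is square
    patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)  # -> (1, dim, H, W) for interpolate
    patch_pos_embed = F.interpolate(
        patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)


# A 224-pixel pre-training grid (14x14 patches of 16px) resized for 480x480 input:
pos = torch.randn(1, 14 * 14 + 1, 768)
print(interpolate_pos_encoding(pos, 480, 480, patch_size=[16, 16]).shape)  # (1, 901, 768)
```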
@@ -359,12 +359,12 @@ class VivitModelIntegrationTest(unittest.TestCase):
         # allowing to interpolate the pre-trained position embeddings in order to use
         # the model on higher resolutions. The DINO model by Facebook AI leverages this
         # to visualize self-attention on higher resolution images.
-        model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)
+        model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400").to(torch_device)
 
-        image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
+        image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
         video = prepare_video()
         inputs = image_processor(
-            video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
+            video, size={"shortest_edge": 480}, crop_size={"height": 232, "width": 232}, return_tensors="pt"
         )
         pixel_values = inputs.pixel_values.to(torch_device)
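Putting it together, the updated slow test loads the Kinetics-400 checkpoint and runs a 232x232 crop through the model with interpolation enabled. A hedged end-to-end sketch (the random clip stands in for the test's `prepare_video()` fixture, and the expected shape assumes the default 32-frame, 2x16x16-tubelet configuration):

```python
import numpy as np
import torch
from transformers import VivitImageProcessor, VivitModel

model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")

# Stand-in for prepare_video(): 32 random RGB frames, channels-last uint8.
video = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(32)]

inputs = image_processor(
    video, size={"shortest_edge": 480}, crop_size={"height": 232, "width": 232}, return_tensors="pt"
)
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# 232 // 16 = 14 positions per side, 32 / 2 = 16 temporal slots:
# 16 * 14 * 14 + 1 CLS token = 3137 tokens of width 768.
print(outputs.last_hidden_state.shape)  # torch.Size([1, 3137, 768])
```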