Enable dynamic resolution for vivit (#30630)

* feat: enable dynamic resolution for vivit

* fix: formatting

* remove: print statement for testing

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/vivit/test_modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/vivit/test_modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/vivit/test_modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/vivit/modeling_vivit.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix: style check

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Jacky Lee 2024-05-09 03:23:39 -07:00 committed by GitHub
parent 60293bd210
commit 8c5b3c19cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 84 additions and 11 deletions

View File

@ -67,11 +67,12 @@ class VivitTubeletEmbeddings(nn.Module):
config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
)
def forward(self, pixel_values):
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
if height != self.image_size or width != self.image_size:
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
f"Image image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
# permute to (batch_size, num_channels, num_frames, height, width)
@ -102,16 +103,50 @@ class VivitEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def forward(self, pixel_values):
batch_size = pixel_values.shape[0]
embeddings = self.patch_embeddings(pixel_values)
def interpolate_pos_encoding(self, embeddings, height, width):
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
cls_tokens = self.cls_token.tile([batch_size, 1, 1])
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# add positional encoding to each token
embeddings = embeddings + self.position_embeddings
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
@ -437,6 +472,8 @@ VIVIT_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@ -482,6 +519,7 @@ class VivitModel(VivitPreTrainedModel):
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
r"""
@ -571,7 +609,7 @@ class VivitModel(VivitPreTrainedModel):
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(pixel_values)
embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
encoder_outputs = self.encoder(
embedding_output,
@ -596,8 +634,18 @@ class VivitModel(VivitPreTrainedModel):
@add_start_docstrings(
"""ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
[CLS] token) e.g. for Kinetics-400.""",
"""
ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
[CLS] token) e.g. for Kinetics-400.
<Tip>
Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
VIVIT_START_DOCSTRING,
)
class VivitForVideoClassification(VivitPreTrainedModel):
@ -622,6 +670,7 @@ class VivitForVideoClassification(VivitPreTrainedModel):
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
r"""
@ -715,6 +764,7 @@ class VivitForVideoClassification(VivitPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

View File

@ -353,3 +353,26 @@ class VivitModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
# Vivit models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
video = prepare_video()
inputs = image_processor(
video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
)
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
# verify the logits shape
expected_shape = torch.Size((1, 3137, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)