Add dynamic resolution input/interpolate position embedding to deit (#31131)

* Added interpolate pos encoding feature and test to deit

* Added interpolate pos encoding feature and test for deit TF model

* Re-added accidentally deleted test for multi_gpu

* Stored only patch_size instead of the entire config and removed commented code

* Update modeling_tf_deit.py to remove extra line

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Kristen Pereira 2024-06-04 05:29:01 -04:00 committed by GitHub
parent d64e4da713
commit de460e28e1
4 changed files with 161 additions and 14 deletions

src/transformers/models/deit/modeling_deit.py

@@ -73,9 +73,53 @@ class DeiTEmbeddings(nn.Module):
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method interpolates the pre-trained position encodings so that the model can be used on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 2
num_positions = self.position_embeddings.shape[1] - 2
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0, :]
dist_pos_embed = self.position_embeddings[:, 1, :]
patch_pos_embed = self.position_embeddings[:, 2:, :]
dim = embeddings.shape[-1]
h0 = height // self.patch_size
w0 = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_length, _ = embeddings.size()
if bool_masked_pos is not None:
@@ -85,9 +129,16 @@ class DeiTEmbeddings(nn.Module):
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
embeddings = embeddings + self.position_embeddings
position_embedding = self.position_embeddings
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
embeddings = embeddings + position_embedding
embeddings = self.dropout(embeddings)
return embeddings
@@ -120,10 +171,6 @@ class DeiTPatchEmbeddings(nn.Module):
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x
@@ -480,6 +527,8 @@ DEIT_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
"""
@@ -528,6 +577,7 @@ class DeiTModel(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -554,7 +604,9 @@ class DeiTModel(DeiTPreTrainedModel):
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
@@ -635,6 +687,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, MaskedImageModelingOutput]:
r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -674,6 +727,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
sequence_output = outputs[0]
@@ -742,6 +796,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -784,6 +839,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
sequence_output = outputs[0]
@@ -901,6 +957,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -910,6 +967,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
sequence_output = outputs[0]
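
For readers skimming the diff, a minimal usage sketch of the new flag (not part of the commit; the checkpoint is the same one the integration tests below use, and the random tensor merely stands in for a real preprocessed image):

import torch
from transformers import DeiTModel

model = DeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224")
pixel_values = torch.randn(1, 3, 480, 640)  # larger than the 224x224 pretraining resolution

with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)

# (480 // 16) * (640 // 16) = 1200 patches, plus the [CLS] and distillation tokens
print(outputs.last_hidden_state.shape)  # torch.Size([1, 1202, 768])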

src/transformers/models/deit/modeling_tf_deit.py

@@ -146,9 +146,42 @@ class TFDeiTEmbeddings(keras.layers.Layer):
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
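"""
This method interpolates the pre-trained position encodings so that the model can be used on higher
resolution images. TF port of the PyTorch implementation above; source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""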
num_patches = embeddings.shape[1] - 2
num_positions = self.position_embeddings.shape[1] - 2
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0, :]
dist_pos_embed = self.position_embeddings[:, 1, :]
patch_pos_embed = self.position_embeddings[:, 2:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = tf.reshape(
patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
)
patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
# tf.image.resize already returns NHWC, so the resized grid can be flattened directly
patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))
return tf.concat(
[tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
)
def call(
self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None, training: bool = False
self,
pixel_values: tf.Tensor,
bool_masked_pos: tf.Tensor | None = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> tf.Tensor:
_, height, width, _ = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
batch_size, seq_length, _ = shape_list(embeddings)
@@ -162,7 +195,11 @@ class TFDeiTEmbeddings(keras.layers.Layer):
cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
embeddings = embeddings + self.position_embeddings
position_embedding = self.position_embeddings
if interpolate_pos_encoding:
position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
embeddings = embeddings + position_embedding
embeddings = self.dropout(embeddings, training=training)
return embeddings
@@ -197,10 +234,7 @@ class TFDeiTPatchEmbeddings(keras.layers.Layer):
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values)
batch_size, height, width, num_channels = shape_list(x)
x = tf.reshape(x, (batch_size, height * width, num_channels))
@@ -599,6 +633,7 @@ class TFDeiTMainLayer(keras.layers.Layer):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -621,7 +656,12 @@ class TFDeiTMainLayer(keras.layers.Layer):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, training=training)
embedding_output = self.embeddings(
pixel_values,
bool_masked_pos=bool_masked_pos,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)
encoder_outputs = self.encoder(
embedding_output,
@@ -705,6 +745,8 @@ DEIT_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -741,6 +783,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False,
) -> Union[Tuple, TFBaseModelOutputWithPooling]:
outputs = self.deit(
@@ -750,6 +793,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training,
)
return outputs
@@ -869,6 +913,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False,
) -> Union[tuple, TFMaskedImageModelingOutput]:
r"""
@@ -909,6 +954,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training,
)
@@ -1003,6 +1049,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False,
) -> Union[tf.Tensor, TFImageClassifierOutput]:
r"""
@@ -1046,6 +1093,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificationLoss):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training,
)
@@ -1126,6 +1174,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
training: bool = False,
) -> Union[tuple, TFDeiTForImageClassificationWithTeacherOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1136,6 +1185,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
training=training,
)
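
One porting note: tf.image.resize operates on NHWC tensors, unlike torch.nn.functional.interpolate, so the position-embedding grid needs no NCHW permute around the resize, which is why the stray transpose was dropped above. A standalone sketch (hypothetical DeiT-base shapes) of what the resize step does:

import math
import tensorflow as tf

dim, num_positions = 768, 196          # DeiT-base: 14x14 grid of 768-d position embeddings
grid = int(math.sqrt(num_positions))   # 14
patch_pos_embed = tf.random.normal((1, grid, grid, dim))                     # already NHWC
resized = tf.image.resize(patch_pos_embed, size=(30, 40), method="bicubic")
print(resized.shape)                   # (1, 30, 40, 768), still NHWC
flattened = tf.reshape(resized, (1, -1, dim))                                # (1, 1200, 768)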

tests/models/deit/test_modeling_deit.py

@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to(
torch_device
)
image_processor = self.default_image_processor
# image size is {"height": 480, "width": 640}
image = prepare_img()
image_processor.size = {"height": 480, "width": 640}
# center crop set to False so image is not center cropped to 224x224
inputs = image_processor(images=image, return_tensors="pt", do_center_crop=False).to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
@slow
@require_accelerate
@require_torch_accelerator
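
The assertion above only checks the logits shape, but the sequence length behind it follows directly from the patch grid; a back-of-envelope sketch, assuming this checkpoint's 16-pixel patches:

patch_size = 16
height, width = 480, 640                                      # processor output, center crop disabled
num_patches = (height // patch_size) * (width // patch_size)  # 30 * 40 = 1200
seq_len = num_patches + 2                                     # + [CLS] + distillation token = 1202
# the classification head then reduces these tokens to the (1, 1000) logits asserted above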

tests/models/deit/test_modeling_tf_deit.py

@@ -293,3 +293,20 @@ class DeiTModelIntegrationTest(unittest.TestCase):
expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")
image_processor = self.default_image_processor
# image size is {"height": 480, "width": 640}
image = prepare_img()
image_processor.size = {"height": 480, "width": 640}
# center crop set to False so image is not center cropped to 224x224
inputs = image_processor(images=image, return_tensors="tf", do_center_crop=False)
# forward pass
outputs = model(**inputs, interpolate_pos_encoding=True)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
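
Mirroring the PyTorch sketch earlier, a minimal TF usage example (not part of the commit; the TF port accepts NCHW pixel values and transposes them to NHWC internally, and the random tensor again stands in for a real preprocessed image):

import tensorflow as tf
from transformers import TFDeiTModel

model = TFDeiTModel.from_pretrained("facebook/deit-base-distilled-patch16-224")
pixel_values = tf.random.normal((1, 3, 480, 640))  # NCHW, as with the PyTorch model

outputs = model(pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # (1, 1202, 768)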