Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add dynamic resolution input/interpolate position embedding to deit (#31131)
* Added interpolate pos encoding feature and test to deit
* Added interpolate pos encoding feature and test for deit TF model
* Re-added accidentally deleted test for multi_gpu
* Storing only patch_size instead of entire config and removed commented code
* Update modeling_tf_deit.py to remove extra line

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent d64e4da713
commit de460e28e1
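For orientation before the diff, here is a minimal usage sketch of the feature this commit adds (not part of the commit itself). It mirrors the integration test added below and assumes the facebook/deit-base-distilled-patch16-224 checkpoint; the DeiTImageProcessor class and the COCO image URL are not shown in the diff and are only illustrative — any RGB image larger than the 224x224 pretraining resolution works.

import requests
import torch
from PIL import Image
from transformers import DeiTImageProcessor, DeiTForImageClassificationWithTeacher

# Example input: any image larger than the 224x224 pretraining resolution.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = DeiTImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")

# Keep the higher resolution: resize to 480x640 and skip the center crop to 224x224.
processor.size = {"height": 480, "width": 640}
inputs = processor(images=image, return_tensors="pt", do_center_crop=False)

# interpolate_pos_encoding=True resizes the pre-trained position embeddings to the
# new patch grid instead of failing on the size mismatch.
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

print(outputs.logits.shape)  # torch.Size([1, 1000])

With the default interpolate_pos_encoding=False, the model still expects the 224x224 resolution it was trained at.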
@@ -73,9 +73,53 @@ class DeiTEmbeddings(nn.Module):
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # return self.position_embeddings
        num_patches = embeddings.shape[1] - 2
        num_positions = self.position_embeddings.shape[1] - 2

        if num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, 0, :]
        dist_pos_embed = self.position_embeddings[:, 1, :]
        patch_pos_embed = self.position_embeddings[:, 2:, :]
        dim = embeddings.shape[-1]
        h0 = height // self.patch_size
        w0 = width // self.patch_size
        # # we add a small number to avoid floating point error in the interpolation
        # # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values)
        batch_size, seq_length, _ = embeddings.size()

        if bool_masked_pos is not None:
@@ -85,9 +129,16 @@ class DeiTEmbeddings(nn.Module):
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
        embeddings = embeddings + self.position_embeddings
        position_embedding = self.position_embeddings

        if interpolate_pos_encoding:
            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = embeddings + position_embedding
        embeddings = self.dropout(embeddings)
        return embeddings
@@ -120,10 +171,6 @@ class DeiTPatchEmbeddings(nn.Module):
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x
@@ -480,6 +527,8 @@ DEIT_INPUTS_DOCSTRING = r"""
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
"""
@@ -528,6 +577,7 @@ class DeiTModel(DeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
@@ -554,7 +604,9 @@ class DeiTModel(DeiTPreTrainedModel):
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
@@ -635,6 +687,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
@@ -674,6 +727,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]
@@ -742,6 +796,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -784,6 +839,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]
@@ -901,6 +957,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -910,6 +967,7 @@ class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]
@@ -146,9 +146,42 @@ class TFDeiTEmbeddings(keras.layers.Layer):
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)

    def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor:
        num_patches = embeddings.shape[1] - 2
        num_positions = self.position_embeddings.shape[1] - 2

        if num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, 0, :]
        dist_pos_embed = self.position_embeddings[:, 1, :]
        patch_pos_embed = self.position_embeddings[:, 2:, :]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # # we add a small number to avoid floating point error in the interpolation
        # # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = tf.reshape(
            patch_pos_embed, (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        )
        patch_pos_embed = tf.image.resize(patch_pos_embed, size=(int(h0), int(w0)), method="bicubic")
        patch_pos_embed = tf.transpose(patch_pos_embed, perm=[0, 2, 3, 1])
        patch_pos_embed = tf.reshape(patch_pos_embed, (1, -1, dim))

        return tf.concat(
            [tf.expand_dims(class_pos_embed, axis=0), tf.expand_dims(dist_pos_embed, axis=0), patch_pos_embed], axis=1
        )

    def call(
        self, pixel_values: tf.Tensor, bool_masked_pos: tf.Tensor | None = None, training: bool = False
        self,
        pixel_values: tf.Tensor,
        bool_masked_pos: tf.Tensor | None = None,
        training: bool = False,
        interpolate_pos_encoding: bool = False,
    ) -> tf.Tensor:
        _, height, width, _ = pixel_values.shape

        embeddings = self.patch_embeddings(pixel_values)
        batch_size, seq_length, _ = shape_list(embeddings)
@@ -162,7 +195,11 @@ class TFDeiTEmbeddings(keras.layers.Layer):
        cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
        distillation_tokens = tf.repeat(self.distillation_token, repeats=batch_size, axis=0)
        embeddings = tf.concat((cls_tokens, distillation_tokens, embeddings), axis=1)
        embeddings = embeddings + self.position_embeddings
        position_embedding = self.position_embeddings
        if interpolate_pos_encoding:
            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = embeddings + position_embedding
        embeddings = self.dropout(embeddings, training=training)
        return embeddings
@@ -197,10 +234,7 @@ class TFDeiTPatchEmbeddings(keras.layers.Layer):
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        x = self.projection(pixel_values)
        batch_size, height, width, num_channels = shape_list(x)
        x = tf.reshape(x, (batch_size, height * width, num_channels))
@@ -599,6 +633,7 @@ class TFDeiTMainLayer(keras.layers.Layer):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -621,7 +656,12 @@ class TFDeiTMainLayer(keras.layers.Layer):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, training=training)
        embedding_output = self.embeddings(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            training=training,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(
            embedding_output,
@@ -705,6 +745,8 @@ DEIT_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -741,6 +783,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
        outputs = self.deit(
@@ -750,6 +793,7 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
            training=training,
        )
        return outputs
@@ -869,6 +913,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        training: bool = False,
    ) -> Union[tuple, TFMaskedImageModelingOutput]:
        r"""
@@ -909,6 +954,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
            training=training,
        )
@@ -1003,6 +1049,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        training: bool = False,
    ) -> Union[tf.Tensor, TFImageClassifierOutput]:
        r"""
@@ -1046,6 +1093,7 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
            training=training,
        )
@@ -1126,6 +1174,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        training: bool = False,
    ) -> Union[tuple, TFDeiTForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1136,6 +1185,7 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
            training=training,
        )
@@ -423,6 +423,28 @@ class DeiTModelIntegrationTest(unittest.TestCase):
        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_interpolate_pos_encoding(self):
        model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to(
            torch_device
        )

        image_processor = self.default_image_processor

        # image size is {"height": 480, "width": 640}
        image = prepare_img()
        image_processor.size = {"height": 480, "width": 640}
        # center crop set to False so image is not center cropped to 224x224
        inputs = image_processor(images=image, return_tensors="pt", do_center_crop=False).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

    @slow
    @require_accelerate
    @require_torch_accelerator
@@ -293,3 +293,20 @@ class DeiTModelIntegrationTest(unittest.TestCase):
        expected_slice = tf.constant([-1.0266, 0.1912, -1.2861])

        self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_interpolate_pos_encoding(self):
        model = TFDeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")

        image_processor = self.default_image_processor
        # image size is {"height": 480, "width": 640}
        image = prepare_img()
        image_processor.size = {"height": 480, "width": 640}
        # center crop set to False so image is not center cropped to 224x224
        inputs = image_processor(images=image, return_tensors="tf", do_center_crop=False)
        # forward pass
        outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the logits
        expected_shape = tf.TensorShape((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)
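As a closing note (not part of the commit), a short sketch of the shape arithmetic behind the interpolation for the test inputs above, assuming the patch-16, 224x224 distilled checkpoint:

import math

# Pre-trained grid: a 224x224 image with 16x16 patches gives 14x14 = 196 patch
# positions, plus the class and distillation tokens -> 198 position embeddings.
num_positions = (224 // 16) ** 2      # 196
grid = int(math.sqrt(num_positions))  # 14

# New grid for the 480x640 test image with interpolate_pos_encoding=True.
h0, w0 = 480 // 16, 640 // 16         # 30, 40
num_patches = h0 * w0                 # 1200

# The 14x14 grid of patch position embeddings is bicubically resized to 30x40,
# then the class and distillation embeddings are concatenated back on, so the
# embedding sequence seen by the encoder has 1200 + 2 = 1202 tokens.
print(grid, (h0, w0), num_patches + 2)  # 14 (30, 40) 1202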