Enable dynamic resolution input for Beit (#31053)

* Initial attempt

* Apply updates from PR suggestions

* Interpolate the relative position bias when interpolate_pos_encoding is True

* Add slow tag for the added tests

* Add interpolate_pos_encoding to DATA2VEC_VISION_INPUTS_DOCSTRING
Omar Salman 2024-06-06 18:47:41 +05:00 committed by GitHub
parent 99895ae5e2
commit 681183974a
4 changed files with 260 additions and 20 deletions
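The new `interpolate_pos_encoding` flag is exercised end to end by the slow tests added at the bottom of this commit. As a quick orientation before the diffs, here is a minimal usage sketch based on those tests; the checkpoint name, image path and target resolution are illustrative, not requirements of the API.

    # Minimal sketch: run BEiT on an image larger than the pre-training resolution.
    import torch
    from PIL import Image
    from transformers import BeitImageProcessor, BeitModel

    model_name = "microsoft/beit-base-patch16-224-pt22k"  # illustrative checkpoint
    model = BeitModel.from_pretrained(model_name)
    processor = BeitImageProcessor.from_pretrained(model_name)

    image = Image.open("tests/fixtures/tests_samples/COCO/000000039769.png")
    # Ask the processor for a resolution larger than the 224x224 used in pre-training.
    inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})

    with torch.no_grad():
        # Without the flag, the size check added to BeitEmbeddings raises a ValueError;
        # with interpolate_pos_encoding=True, the position encodings are resized on the fly.
        outputs = model(**inputs, interpolate_pos_encoding=True)

    print(outputs.last_hidden_state.shape)  # (batch_size, num_patches + 1, hidden_size)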

src/transformers/models/beit/modeling_beit.py

@@ -137,6 +137,12 @@ class BeitEmbeddings(nn.Module):
else:
self.mask_token = None
self.patch_embeddings = BeitPatchEmbeddings(config)
self.patch_size = config.patch_size
self.image_size = (
config.image_size
if isinstance(config.image_size, collections.abc.Iterable)
else (config.image_size, config.image_size)
)
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
@@ -144,7 +150,55 @@ class BeitEmbeddings(nn.Module):
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings, (patch_height, patch_width) = self.patch_embeddings(
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
@@ -158,7 +212,10 @@ class BeitEmbeddings(nn.Module):
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
if self.position_embeddings is not None:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
if interpolate_pos_encoding:
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
else:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
@@ -191,7 +248,11 @@ class BeitPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
pixel_values: torch.Tensor,
position_embedding: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -251,6 +312,7 @@ class BeitSelfAttention(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
@@ -265,7 +327,9 @@ class BeitSelfAttention(nn.Module):
# Add relative position bias if present.
if self.relative_position_bias is not None:
attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
attention_scores = attention_scores + self.relative_position_bias(
interpolate_pos_encoding, attention_scores.shape[2]
).unsqueeze(0)
# Add shared relative position bias if provided.
if relative_position_bias is not None:
@@ -342,8 +406,11 @@ class BeitAttention(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
self_outputs = self.attention(
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
)
attention_output = self.output(self_outputs[0], hidden_states)
@@ -407,12 +474,14 @@ class BeitLayer(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
@@ -471,12 +540,21 @@ class BeitRelativePositionBias(nn.Module):
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
def forward(self) -> torch.Tensor:
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
if interpolate_pos_encoding:
relative_position_bias = nn.functional.interpolate(
relative_position_bias.unsqueeze(1),
size=(dim_size, dim_size),
mode="bilinear",
align_corners=False,
).squeeze(1)
return relative_position_bias
class BeitEncoder(nn.Module):
@@ -508,6 +586,7 @@ class BeitEncoder(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@@ -528,9 +607,13 @@ class BeitEncoder(nn.Module):
)
else:
relative_position_bias = (
self.relative_position_bias() if self.relative_position_bias is not None else None
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
hidden_states = layer_outputs[0]
@@ -607,6 +690,8 @@ BEIT_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -658,6 +743,7 @@ class BeitModel(BeitPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, BeitModelOutputWithPooling]:
r"""
@@ -680,7 +766,9 @@ class BeitModel(BeitPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
embedding_output, (patch_height, patch_width) = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
@@ -688,6 +776,7 @@ class BeitModel(BeitPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
@@ -755,6 +844,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]:
r"""
@@ -800,6 +890,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
@@ -858,6 +949,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
@@ -872,6 +964,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
@@ -1215,6 +1308,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, SemanticSegmenterOutput]:
r"""
@@ -1255,6 +1349,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
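Before the mirrored Data2VecVision changes, here is a standalone sketch of what the new BeitRelativePositionBias.forward path does when interpolate_pos_encoding=True: each attention head's pre-trained bias map is treated as a one-channel image and bilinearly resized to match the attention matrix of the new sequence length. The head count and sequence lengths below are illustrative, not taken from any particular checkpoint.

    import torch
    import torch.nn as nn

    num_heads = 12
    old_seq_len = 197   # e.g. a 14x14 patch grid plus the CLS token
    new_seq_len = 901   # e.g. a 30x30 patch grid plus the CLS token

    # Pre-trained bias: one (old_seq_len, old_seq_len) map per attention head.
    relative_position_bias = torch.randn(num_heads, old_seq_len, old_seq_len)

    # Treat each head's map as a 1-channel image and resize it bilinearly,
    # mirroring the shape transformation the new forward performs.
    resized = nn.functional.interpolate(
        relative_position_bias.unsqueeze(1),  # (num_heads, 1, L, L)
        size=(new_seq_len, new_seq_len),
        mode="bilinear",
        align_corners=False,
    ).squeeze(1)                              # (num_heads, L', L')

    print(resized.shape)  # torch.Size([12, 901, 901])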

src/transformers/models/data2vec/modeling_data2vec_vision.py

@@ -136,6 +136,12 @@ class Data2VecVisionEmbeddings(nn.Module):
else:
self.mask_token = None
self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
self.patch_size = config.patch_size
self.image_size = (
config.image_size
if isinstance(config.image_size, collections.abc.Iterable)
else (config.image_size, config.image_size)
)
num_patches = self.patch_embeddings.num_patches
if config.use_absolute_position_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
@@ -143,7 +149,55 @@ class Data2VecVisionEmbeddings(nn.Module):
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings, (patch_height, patch_width) = self.patch_embeddings(
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
@@ -157,7 +211,10 @@ class Data2VecVisionEmbeddings(nn.Module):
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
if self.position_embeddings is not None:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
if interpolate_pos_encoding:
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
else:
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
@@ -191,7 +248,11 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
pixel_values: torch.Tensor,
position_embedding: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -252,6 +313,7 @@ class Data2VecVisionSelfAttention(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
@@ -266,7 +328,9 @@ class Data2VecVisionSelfAttention(nn.Module):
# Add relative position bias if present.
if self.relative_position_bias is not None:
attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
attention_scores = attention_scores + self.relative_position_bias(
interpolate_pos_encoding, attention_scores.shape[2]
).unsqueeze(0)
# Add shared relative position bias if provided.
if relative_position_bias is not None:
@@ -345,8 +409,11 @@ class Data2VecVisionAttention(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
self_outputs = self.attention(
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
)
attention_output = self.output(self_outputs[0], hidden_states)
@@ -415,12 +482,14 @@ class Data2VecVisionLayer(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
@@ -480,12 +549,21 @@ class Data2VecVisionRelativePositionBias(nn.Module):
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
def forward(self) -> torch.Tensor:
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
if interpolate_pos_encoding:
relative_position_bias = nn.functional.interpolate(
relative_position_bias.unsqueeze(1),
size=(dim_size, dim_size),
mode="bilinear",
align_corners=False,
).squeeze(1)
return relative_position_bias
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
@@ -518,6 +596,7 @@ class Data2VecVisionEncoder(nn.Module):
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@@ -538,9 +617,13 @@ class Data2VecVisionEncoder(nn.Module):
)
else:
relative_position_bias = (
self.relative_position_bias() if self.relative_position_bias is not None else None
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
hidden_states = layer_outputs[0]
@@ -618,6 +701,8 @@ DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -670,6 +755,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
r"""
@@ -692,7 +778,9 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
embedding_output, (patch_height, patch_width) = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
@@ -700,6 +788,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
@@ -772,6 +861,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
@@ -786,6 +876,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
@@ -1141,6 +1232,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[tuple, SemanticSegmenterOutput]:
r"""
@@ -1181,6 +1273,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
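The absolute position-embedding path added to BeitEmbeddings (and mirrored in Data2VecVisionEmbeddings above) can be pictured in isolation as well: the per-patch position embeddings are laid out on their 2D grid, resized bicubically to the new patch grid, and flattened back, while the class-token embedding is left untouched. The grid sizes below are illustrative, and the small +0.1 offset the real implementation borrows from the DINO reference code is omitted here for clarity.

    import torch
    import torch.nn as nn

    hidden_size = 768
    old_grid = 14                 # e.g. 224 // 16
    new_h, new_w = 30, 30         # e.g. 480 // 16

    # Pre-trained embeddings: one CLS position plus old_grid**2 patch positions.
    position_embeddings = torch.randn(1, old_grid**2 + 1, hidden_size)
    class_pos_embed = position_embeddings[:, :1]
    patch_pos_embed = position_embeddings[:, 1:]

    # Grid layout -> bicubic resize -> flatten back to a sequence.
    patch_pos_embed = patch_pos_embed.reshape(1, old_grid, old_grid, hidden_size).permute(0, 3, 1, 2)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed, size=(new_h, new_w), mode="bicubic", align_corners=False
    )
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, hidden_size)

    interpolated = torch.cat((class_pos_embed, patch_pos_embed), dim=1)
    print(interpolated.shape)  # torch.Size([1, 901, 768])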

tests/models/beit/test_modeling_beit.py

@@ -545,6 +545,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((160, 160))
self.assertEqual(segmentation[0].shape, expected_shape)
@slow
def test_inference_interpolate_pos_encoding(self):
model_name = "microsoft/beit-base-patch16-224-pt22k"
model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device)
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
processor = BeitImageProcessor.from_pretrained(model_name)
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
pixel_values = inputs.pixel_values.to(torch_device)
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
# images than what the model supports.
self.assertFalse(processor.do_center_crop)
with torch.no_grad():
with self.assertRaises(ValueError, msg="doesn't match model"):
model(pixel_values, interpolate_pos_encoding=False)
# with interpolate_pos_encoding being True the model should process the higher resolution image
# successfully and produce the expected output.
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
expected_shape = torch.Size((1, 1801, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
@require_torch
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):

tests/models/data2vec/test_modeling_data2vec_vision.py

@@ -341,3 +341,30 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase):
expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2)
@slow
def test_inference_interpolate_pos_encoding(self):
model_name = "facebook/data2vec-vision-base-ft1k"
model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(
torch_device
)
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
processor = BeitImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k")
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
pixel_values = inputs.pixel_values.to(torch_device)
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
# images than what the model supports.
self.assertFalse(processor.do_center_crop)
with torch.no_grad():
with self.assertRaises(ValueError, msg="doesn't match model"):
model(pixel_values, interpolate_pos_encoding=False)
# with interpolate_pos_encoding being True the model should process the higher resolution image
# successfully and produce the expected output.
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
expected_shape = torch.Size((1, 1801, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)