mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Enable dynamic resolution input for Beit (#31053)
* Initial attempt * Updates: PR suggestions * Interpolate the relative position bias when interpolate_pos_encoding is True * Add slow tag for the added tests * Add in DATA2VEC_VISION_INPUTS_DOCSTRING
This commit is contained in:
parent
99895ae5e2
commit
681183974a
@ -137,6 +137,12 @@ class BeitEmbeddings(nn.Module):
|
||||
else:
|
||||
self.mask_token = None
|
||||
self.patch_embeddings = BeitPatchEmbeddings(config)
|
||||
self.patch_size = config.patch_size
|
||||
self.image_size = (
|
||||
config.image_size
|
||||
if isinstance(config.image_size, collections.abc.Iterable)
|
||||
else (config.image_size, config.image_size)
|
||||
)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
if config.use_absolute_position_embeddings:
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
@ -144,7 +150,55 @@ class BeitEmbeddings(nn.Module):
|
||||
self.position_embeddings = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
|
||||
higher resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h = height // self.patch_size
|
||||
w = width // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h, w = h + 0.1, w + 0.1
|
||||
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||
)
|
||||
@ -158,7 +212,10 @@ class BeitEmbeddings(nn.Module):
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
if self.position_embeddings is not None:
|
||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||
if interpolate_pos_encoding:
|
||||
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
@ -191,7 +248,11 @@ class BeitPatchEmbeddings(nn.Module):
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
position_embedding: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
@ -251,6 +312,7 @@ class BeitSelfAttention(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
@ -265,7 +327,9 @@ class BeitSelfAttention(nn.Module):
|
||||
|
||||
# Add relative position bias if present.
|
||||
if self.relative_position_bias is not None:
|
||||
attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
|
||||
attention_scores = attention_scores + self.relative_position_bias(
|
||||
interpolate_pos_encoding, attention_scores.shape[2]
|
||||
).unsqueeze(0)
|
||||
|
||||
# Add shared relative position bias if provided.
|
||||
if relative_position_bias is not None:
|
||||
@ -342,8 +406,11 @@ class BeitAttention(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
|
||||
self_outputs = self.attention(
|
||||
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
|
||||
@ -407,12 +474,14 @@ class BeitLayer(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
relative_position_bias=relative_position_bias,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
@ -471,12 +540,21 @@ class BeitRelativePositionBias(nn.Module):
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
||||
|
||||
def forward(self) -> torch.Tensor:
|
||||
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
|
||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
if interpolate_pos_encoding:
|
||||
relative_position_bias = nn.functional.interpolate(
|
||||
relative_position_bias.unsqueeze(1),
|
||||
size=(dim_size, dim_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).squeeze(1)
|
||||
|
||||
return relative_position_bias
|
||||
|
||||
|
||||
class BeitEncoder(nn.Module):
|
||||
@ -508,6 +586,7 @@ class BeitEncoder(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -528,9 +607,13 @@ class BeitEncoder(nn.Module):
|
||||
)
|
||||
else:
|
||||
relative_position_bias = (
|
||||
self.relative_position_bias() if self.relative_position_bias is not None else None
|
||||
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
|
||||
if self.relative_position_bias is not None
|
||||
else None
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
)
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -607,6 +690,8 @@ BEIT_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -658,6 +743,7 @@ class BeitModel(BeitPreTrainedModel):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BeitModelOutputWithPooling]:
|
||||
r"""
|
||||
@ -680,7 +766,9 @@ class BeitModel(BeitPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -688,6 +776,7 @@ class BeitModel(BeitPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
@ -755,6 +844,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, MaskedLMOutput]:
|
||||
r"""
|
||||
@ -800,6 +890,7 @@ class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -858,6 +949,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, ImageClassifierOutput]:
|
||||
r"""
|
||||
@ -872,6 +964,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1215,6 +1308,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, SemanticSegmenterOutput]:
|
||||
r"""
|
||||
@ -1255,6 +1349,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True, # we need the intermediate hidden states
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
@ -136,6 +136,12 @@ class Data2VecVisionEmbeddings(nn.Module):
|
||||
else:
|
||||
self.mask_token = None
|
||||
self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
|
||||
self.patch_size = config.patch_size
|
||||
self.image_size = (
|
||||
config.image_size
|
||||
if isinstance(config.image_size, collections.abc.Iterable)
|
||||
else (config.image_size, config.image_size)
|
||||
)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
if config.use_absolute_position_embeddings:
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
@ -143,7 +149,55 @@ class Data2VecVisionEmbeddings(nn.Module):
|
||||
self.position_embeddings = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
|
||||
higher resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h = height // self.patch_size
|
||||
w = width // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h, w = h + 0.1, w + 0.1
|
||||
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||
)
|
||||
@ -157,7 +211,10 @@ class Data2VecVisionEmbeddings(nn.Module):
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
if self.position_embeddings is not None:
|
||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||
if interpolate_pos_encoding:
|
||||
cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
@ -191,7 +248,11 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
position_embedding: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
@ -252,6 +313,7 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
@ -266,7 +328,9 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
|
||||
# Add relative position bias if present.
|
||||
if self.relative_position_bias is not None:
|
||||
attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
|
||||
attention_scores = attention_scores + self.relative_position_bias(
|
||||
interpolate_pos_encoding, attention_scores.shape[2]
|
||||
).unsqueeze(0)
|
||||
|
||||
# Add shared relative position bias if provided.
|
||||
if relative_position_bias is not None:
|
||||
@ -345,8 +409,11 @@ class Data2VecVisionAttention(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias)
|
||||
self_outputs = self.attention(
|
||||
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
|
||||
@ -415,12 +482,14 @@ class Data2VecVisionLayer(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
relative_position_bias=relative_position_bias,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
@ -480,12 +549,21 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
||||
|
||||
def forward(self) -> torch.Tensor:
|
||||
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
|
||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
if interpolate_pos_encoding:
|
||||
relative_position_bias = nn.functional.interpolate(
|
||||
relative_position_bias.unsqueeze(1),
|
||||
size=(dim_size, dim_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).squeeze(1)
|
||||
|
||||
return relative_position_bias
|
||||
|
||||
|
||||
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
|
||||
@ -518,6 +596,7 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -538,9 +617,13 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
)
|
||||
else:
|
||||
relative_position_bias = (
|
||||
self.relative_position_bias() if self.relative_position_bias is not None else None
|
||||
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
|
||||
if self.relative_position_bias is not None
|
||||
else None
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
)
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@ -618,6 +701,8 @@ DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
@ -670,6 +755,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
|
||||
r"""
|
||||
@ -692,7 +778,9 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -700,6 +788,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
@ -772,6 +861,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, ImageClassifierOutput]:
|
||||
r"""
|
||||
@ -786,6 +876,7 @@ class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@ -1141,6 +1232,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, SemanticSegmenterOutput]:
|
||||
r"""
|
||||
@ -1181,6 +1273,7 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True, # we need the intermediate hidden states
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
@ -545,6 +545,31 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size((160, 160))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "microsoft/beit-base-patch16-224-pt22k"
|
||||
model = BeitModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(torch_device)
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
processor = BeitImageProcessor.from_pretrained(model_name)
|
||||
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||
# images than what the model supports.
|
||||
self.assertFalse(processor.do_center_crop)
|
||||
with torch.no_grad():
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
model(pixel_values, interpolate_pos_encoding=False)
|
||||
|
||||
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||
# successfully and produce the expected output.
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
expected_shape = torch.Size((1, 1801, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
|
@ -341,3 +341,30 @@ class Data2VecVisionModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
expected_top2 = [model.config.label2id[i] for i in ["remote control, remote", "tabby, tabby cat"]]
|
||||
self.assertEqual(logits[0].topk(2).indices.cpu().tolist(), expected_top2)
|
||||
|
||||
@slow
|
||||
def test_inference_interpolate_pos_encoding(self):
|
||||
model_name = "facebook/data2vec-vision-base-ft1k"
|
||||
model = Data2VecVisionModel.from_pretrained(model_name, **{"use_absolute_position_embeddings": True}).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
processor = BeitImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k")
|
||||
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||
# images than what the model supports.
|
||||
self.assertFalse(processor.do_center_crop)
|
||||
with torch.no_grad():
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
model(pixel_values, interpolate_pos_encoding=False)
|
||||
|
||||
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||
# successfully and produce the expected output.
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values, interpolate_pos_encoding=True)
|
||||
|
||||
expected_shape = torch.Size((1, 1801, 768))
|
||||
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||
|
Loading…
Reference in New Issue
Block a user