diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 58b28866646..f972e021f3e 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -41,6 +41,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ...utils.backbone_utils import BackboneMixin from .configuration_beit import BeitConfig @@ -150,41 +151,46 @@ class BeitEmbeddings(nn.Module): self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows the model to interpolate the pre-trained position encodings so that it can be used on - higher resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ + num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] - dim = embeddings.shape[-1] - h = height // self.patch_size - w = width // self.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h, w = h + 0.1, w + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) - if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]: - raise ValueError("Width or height does not match with the interpolated position embeddings") patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, @@ -566,7 +572,7 @@ class BeitRelativePositionBias(nn.Module): old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) new_sub_table = 
nn.functional.interpolate( - old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear" + old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear" ) new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 46e3a6005b0..2392961037f 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch BLIP model.""" -import math import warnings from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -33,6 +32,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel @@ -232,38 +232,46 @@ class BlipVisionEmbeddings(nn.Module): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ + num_patches = embeddings.shape[1] - 1 num_positions = self.position_embedding.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.position_embedding + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] - class_pos_embed = self.position_embedding[:, 0, :] - patch_pos_embed = self.position_embedding[:, 1:, :] dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape
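[Review note] The recurring scale_factor -> size change in these hunks is easiest to see in isolation. With scale_factor, the output size is floor(input_size * scale), which is why the upstream DINO code (and the removed lines here) nudged the target by +0.1 before dividing; with size, the target grid is stated exactly and can also be a traced value. A minimal standalone sketch with toy sizes, not tied to any checkpoint:

import torch
import torch.nn as nn

# a pretrained 14x14 grid of 64-d position embeddings, as a (1, C, H, W) map
pos_grid = torch.randn(1, 196, 64).reshape(1, 14, 14, 64).permute(0, 3, 1, 2)

# old style: floor(14 * scale) decides the output size, so the target was
# fudged by +0.1 to keep float error from flooring one cell short
old = nn.functional.interpolate(
    pos_grid, scale_factor=((17 + 0.1) / 14, (17 + 0.1) / 14), mode="bicubic", align_corners=False
)

# new style: request the 17x17 grid directly, no rounding trick needed
new = nn.functional.interpolate(pos_grid, size=(17, 17), mode="bicubic", align_corners=False)

assert old.shape == new.shape == (1, 64, 17, 17)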
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index fba4c98696a..8c3b5254ea8 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig @@ -198,38 +199,46 @@ class Blip2VisionEmbeddings(nn.Module): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ + num_patches = embeddings.shape[1] - 1 num_positions = self.position_embedding.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.position_embedding + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] - class_pos_embed = self.position_embedding[:, 0, :] - patch_pos_embed = self.position_embedding[:, 1:, :] dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 /
math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index fca47c524e5..4d252ce1f19 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -39,6 +39,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_data2vec_vision import Data2VecVisionConfig @@ -149,41 +150,46 @@ class Data2VecVisionEmbeddings(nn.Module): self.position_embeddings = None self.dropout = nn.Dropout(config.hidden_dropout_prob) + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows the model to interpolate the pre-trained position encodings so that it can be used on - higher resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ + num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] - dim = embeddings.shape[-1] - h = height // self.patch_size - w = width // self.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h, w = h + 0.1, w + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) - if int(h) != patch_pos_embed.shape[-2] 
or int(w) != patch_pos_embed.shape[-1]: - raise ValueError("Width or height does not match with the interpolated position embeddings") patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, @@ -575,7 +581,7 @@ class Data2VecVisionRelativePositionBias(nn.Module): old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) new_sub_table = nn.functional.interpolate( - old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear" + old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear" ) new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 0f5bef5710a..03194c15d98 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -40,6 +40,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_deit import DeiTConfig @@ -77,39 +78,43 @@ class DeiTEmbeddings(nn.Module): def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and 2 class embeddings. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ - # return self.position_embeddings num_patches = embeddings.shape[1] - 2 num_positions = self.position_embeddings.shape[1] - 2 - if num_patches == num_positions and height == width: + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0, :] - dist_pos_embed = self.position_embeddings[:, 1, :] - patch_pos_embed = self.position_embeddings[:, 2:, :] + class_and_dist_pos_embed = self.position_embeddings[:, :2] + patch_pos_embed = self.position_embeddings[:, 2:] + dim = embeddings.shape[-1] - h0 = height // self.patch_size - w0 = width // self.patch_size - # # we add a small number to avoid floating point error in the interpolation - # # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1) def forward( self, diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index a1f34bfe40d..dca17adf2b0 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -27,7 +27,13 @@ from ....activations import ACT2FN from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer -from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ....utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + torch_int, +) from ....utils.backbone_utils import load_backbone from .configuration_vit_hybrid import ViTHybridConfig @@ -60,41 +66,49 @@ class ViTHybridEmbeddings(nn.Module): num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size self.config = config + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, 
embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] - height = height // self.config.patch_size - width = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - height, width = height + 0.1, width + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) - if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: - raise ValueError(f"Invalid height or width: {height}, {width}") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 3a7959c27d8..160c5ae69f3 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ...utils.backbone_utils import BackboneMixin from .configuration_dinov2 import Dinov2Config @@ -71,42 +72,48 @@ class Dinov2Embeddings(nn.Module): num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size self.config = config def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This 
method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] - height = height // self.config.patch_size - width = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - height, width = height + 0.1, width + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) target_dtype = patch_pos_embed.dtype patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.to(dtype=torch.float32), - scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))), + patch_pos_embed.to(torch.float32), + size=(new_height, new_width), mode="bicubic", align_corners=False, ).to(dtype=target_dtype) - if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: - raise ValueError("Width or height does not match with the interpolated position embeddings") + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 115808a6b11..8d639131b84 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -166,38 +166,49 @@ class DonutSwinEmbeddings(nn.Module): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: 
torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index b3e4b86a2a4..1587493643e 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -152,7 +152,7 @@ class DPTViTHybridEmbeddings(nn.Module): posemb_tok = posemb[:, :start_index] posemb_grid = posemb[0, start_index:] - old_grid_size = int(math.sqrt(len(posemb_grid))) + old_grid_size = torch_int(len(posemb_grid) ** 0.5) posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2) posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear") @@ -626,7 +626,7 @@ class DPTReassembleStage(nn.Module): if patch_height is not None and patch_width is not None: hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels) else: - size = int(math.sqrt(sequence_length)) + size = torch_int(sequence_length**0.5) hidden_state = hidden_state.reshape(batch_size, size, size, num_channels) hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() diff 
--git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 7f7ef7dfdda..589385dffec 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -34,6 +34,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_flava import ( FlavaConfig, @@ -259,42 +260,49 @@ class FlavaImageEmbeddings(nn.Module): num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size self.config = config + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/image_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ - npatch = embeddings.shape[1] - 1 - num_pos = self.position_embeddings.shape[1] - 1 - if npatch == num_pos and height == width: + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] - num_h_patches = height // self.config.patch_size - num_w_patches = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1 + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(num_pos)), int(math.sqrt(num_pos)), dim).permute(0, 3, 1, 2), - scale_factor=(num_h_patches / math.sqrt(num_pos), num_w_patches / math.sqrt(num_pos)), + patch_pos_embed, + size=(new_height, new_width), mode="bicubic", align_corners=False, ) - if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]: - raise ValueError( - f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the " - f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})" - ) 
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward( self, diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 2a0d4f3c0e4..3a2ccab8429 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -15,7 +15,6 @@ """PyTorch GroupViT model.""" import collections.abc -import math from dataclasses import dataclass from typing import Any, Optional, Tuple, Union @@ -34,6 +33,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig @@ -365,39 +365,44 @@ class GroupViTVisionEmbeddings(nn.Module): self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size)) self.dropout = nn.Dropout(config.dropout) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.patch_size = config.patch_size self.config = config def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and no class embeddings. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ - npatch = embeddings.shape[1] - if npatch == self.position_embeddings.shape[1] and height == width: + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings + patch_pos_embed = self.position_embeddings - num_original_pos_embed = patch_pos_embed.shape[1] + dim = embeddings.shape[-1] - feat_height = height // self.config.patch_size - feat_width = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - feat_height, feat_width = feat_height + 0.1, feat_width + 0.1 - original_height = original_width = math.sqrt(num_original_pos_embed) - reshaped_patch_pos_embed = patch_pos_embed.reshape(1, int(original_height), int(original_width), dim).permute( - 0, 3, 1, 2 - ) - scale_factor = (feat_height / original_height, feat_width / original_width) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( - 
reshaped_patch_pos_embed, - scale_factor=scale_factor, + patch_pos_embed, + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index 2b07efe347b..de327eb91d2 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ...utils.backbone_utils import BackboneMixin from .configuration_hiera import HieraConfig @@ -320,46 +321,48 @@ class HieraEmbeddings(nn.Module): self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int ) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing, no class embeddings, and different patch strides. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] num_positions = pos_embeds.shape[1] - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return pos_embeds + dim = embeddings.shape[-1] - h0 = height // self.patch_stride[0] - w0 = width // self.patch_stride[1] - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - pos_embeds = pos_embeds.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_stride[0] + new_width = width // self.patch_stride[1] + + sqrt_num_positions = torch_int(num_positions**0.5) + pos_embeds = pos_embeds.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) pos_embeds = pos_embeds.permute(0, 3, 1, 2) + pos_embeds = nn.functional.interpolate( pos_embeds, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) - if int(h0) != pos_embeds.shape[-2] or int(w0) != pos_embeds.shape[-1]: - raise ValueError("The interpolated position encoding does not have the right size") + pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim) return pos_embeds def get_position_embedding( self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool ) -> torch.FloatTensor: - position_embeddings = self.position_embeddings - position_embeddings = ( - self.interpolate_pos_encoding(embeddings, position_embeddings, height, width) + return ( + self.interpolate_pos_encoding(embeddings, self.position_embeddings, height, width) if interpolate_pos_encoding - else position_embeddings + else 
self.position_embeddings ) - return position_embeddings def forward( self, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index f59f72a6699..ba77afe9f7c 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig @@ -102,38 +103,46 @@ class InstructBlipVisionEmbeddings(nn.Module): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ + num_patches = embeddings.shape[1] - 1 num_positions = self.position_embedding.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.position_embedding + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] - class_pos_embed = self.position_embedding[:, 0, :] - patch_pos_embed = self.position_embedding[:, 1:, :] dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape
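[Review note] The `not torch.jit.is_tracing()` guard added throughout exists because torch.jit.trace freezes Python control flow: if the example input used for tracing hit the early-return branch, the traced graph would return the raw table unconditionally and break at other resolutions. A toy module showing the pattern end to end (illustrative names and sizes only, not code from this patch):

import torch
import torch.nn as nn

class ToyEmbeddings(nn.Module):
    def __init__(self, patch_size=4, grid=4, dim=8):
        super().__init__()
        self.patch_size = patch_size
        # pretrained position table for a 4x4 patch grid
        self.pos = nn.Parameter(torch.randn(1, grid * grid, dim))

    def forward(self, pixel_values):
        _, _, height, width = pixel_values.shape
        new_height = height // self.patch_size
        new_width = width // self.patch_size
        # always interpolate when tracing so the exported graph stays shape-dynamic
        if not torch.jit.is_tracing() and new_height * new_width == self.pos.shape[1]:
            return self.pos
        side = int(self.pos.shape[1] ** 0.5)  # static: comes from a parameter shape
        grid = self.pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, size=(new_height, new_width), mode="bicubic", align_corners=False)
        return grid.permute(0, 2, 3, 1).reshape(1, new_height * new_width, -1)

traced = torch.jit.trace(ToyEmbeddings(), torch.randn(1, 3, 16, 16))  # traced at the pretrain size
print(traced(torch.randn(1, 3, 24, 24)).shape)  # torch.Size([1, 36, 8]): the branch was not baked in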
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 701402241d4..8cb813e0ac5 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -44,6 +44,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM from .configuration_instructblipvideo import ( @@ -110,38 +111,46 @@ class InstructBlipVideoVisionEmbeddings(nn.Module): self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ + num_patches = embeddings.shape[1] - 1 num_positions = self.position_embedding.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.position_embedding + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding + + class_pos_embed = self.position_embedding[:, :1] + patch_pos_embed = self.position_embedding[:, 1:] - class_pos_embed = self.position_embedding[:, 0, :] - patch_pos_embed = self.position_embedding[:, 1:, :] dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0),
patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, _, height, width = pixel_values.shape diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index ef607ec8117..9a40e050459 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -29,6 +29,7 @@ from ...file_utils import ModelOutput from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import torch_int from ...utils.backbone_utils import BackboneMixin from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -162,38 +163,48 @@ class MaskFormerSwinEmbeddings(nn.Module): self.norm = nn.LayerNorm(config.embed_dim) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] - 1 num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] + + class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + size=(new_height, new_width), mode="bicubic", align_corners=False, ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 
1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) def forward(self, pixel_values, interpolate_pos_encoding): _, num_channels, height, width = pixel_values.shape diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index d398d3d8c4d..b6c233c7611 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -37,6 +37,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_perceiver import PerceiverConfig @@ -2767,13 +2768,19 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding): def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: num_positions = position_embeddings.shape[0] - new_height = new_width = math.sqrt(num_positions) - position_embeddings = position_embeddings.reshape( - 1, int(new_height), int(new_width), self._num_channels - ).permute(0, 3, 1, 2) + new_height = new_width = torch_int(num_positions**0.5) + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and height == new_height and width == new_width: + return position_embeddings + + position_embeddings = position_embeddings.reshape(1, new_height, new_width, self._num_channels).permute( 0, 3, 1, 2 ) + position_embeddings = nn.functional.interpolate( position_embeddings, - scale_factor=(height / new_height, width / new_width), + size=(height, width), mode="bicubic", align_corners=False, ) @@ -2787,7 +2794,6 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding): if interpolate_pos_encoding: height, width = input_size - height, width = height + 0.1, width + 0.1 position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width) if batch_size is not None: diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 306cc13122d..7befa4dad02 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -123,7 +123,9 @@ class PvtPatchEmbeddings(nn.Module): def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: num_patches = height * width - if num_patches == self.config.image_size * self.config.image_size: + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == self.config.image_size * self.config.image_size: return self.position_embeddings embeddings = embeddings.reshape(1, height, width, -1).permute(0, 3, 1, 2) interpolated_embeddings = F.interpolate(embeddings, size=(height, width), mode="bilinear")
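[Review note] Every int(math.sqrt(...)) -> torch_int(... ** 0.5) swap in this patch leans on the same helper contract: in eager mode torch_int is just int(), but while tracing it keeps tensor-valued quantities in the graph instead of collapsing them to Python constants. Roughly (paraphrased from transformers.utils; see the source for the exact guards):

import torch

def torch_int(x):
    # while tracing, keep tensor inputs as int64 tensors; otherwise return a plain Python int
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)

assert torch_int(196 ** 0.5) == 14  # eager: an ordinary int, valid anywhere a size is expected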
diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py index 3b460c0d95e..174aeaad00a 100644 --- a/src/transformers/models/seggpt/modeling_seggpt.py +++ b/src/transformers/models/seggpt/modeling_seggpt.py @@ -15,7 +15,6 @@ """PyTorch SegGpt model.""" import collections.abc -import math from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union @@ -32,6 +31,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, logging, replace_return_docstrings, + torch_int, ) from .configuration_seggpt import SegGptConfig @@ -155,9 +155,10 @@ class SegGptEmbeddings(nn.Module): def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor: patch_pos_embed = self.position_embeddings[:, 1:] num_patches = patch_pos_embed.shape[1] - pretrain_patch_size = int(math.sqrt(num_patches)) + pretrain_patch_size = torch_int(num_patches**0.5) - if pretrain_patch_size != height or pretrain_patch_size != width: + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width: patch_pos_embed = F.interpolate( patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2), size=(height, width),
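[Review note] seggpt is the one file where the tracing condition is OR-ed into the check rather than AND-ed, because its guard wraps the interpolation work itself instead of an early return. The two shapes are the same predicate under De Morgan:

# early-return style used by most files: skip work only when NOT tracing and the size matches
if not torch.jit.is_tracing() and size_matches:
    return pos_embed
# ... interpolate ...

# do-work style used by seggpt: interpolate when tracing OR the size differs
if torch.jit.is_tracing() or not size_matches:
    # ... interpolate ...

(size_matches and pos_embed are placeholders here, not names from the patch.)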
diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index 3439aa49dcb..1d35d1d44cf 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -38,6 +38,7 @@ from ...utils import (
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
 
@@ -269,38 +270,38 @@ class SiglipVisionEmbeddings(nn.Module):
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs)
-        that allows the model to interpolate the pre-trained position encodings such that it can be usable on
-        higher resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing and no class embeddings.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
-        position_embeddings = self.position_embedding.weight.unsqueeze(0)
+
         num_patches = embeddings.shape[1]
-        num_positions = position_embeddings.shape[1]
-        if num_patches == num_positions and height == width:
-            return position_embeddings
+        num_positions = self.position_embedding.weight.shape[0]
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embedding(self.position_ids)
+
+        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
 
         dim = embeddings.shape[-1]
-        height = height // self.patch_size
-        width = width // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        height, width = height + 0.1, width + 0.1
 
-        patch_pos_embed = position_embeddings.reshape(
-            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
-        )
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
-        if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
-            raise ValueError("Width or height does not match with the interpolated position embeddings")
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return patch_pos_embed
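Every rewritten method performs the same four steps: slice off the class token if there is one, fold the flat position sequence back into its square grid, resize that grid bicubically to the new patch grid, and flatten again. A standalone walkthrough with concrete numbers, assuming a ViT-style checkpoint pre-trained at 224x224 with 16x16 patches (14x14 = 196 positions) applied to a 336x336 image (21x21 = 441 positions); all names here are illustrative:

import torch
import torch.nn.functional as F

dim = 768                                    # hidden size
patch_pos_embed = torch.randn(1, 196, dim)   # pre-trained 14x14 grid, flattened

new_height = 336 // 16                       # 21 patches per side at the new resolution
new_width = 336 // 16

sqrt_num_positions = int(196**0.5)           # 14; the models use torch_int here
grid = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
grid = grid.permute(0, 3, 1, 2)              # (1, dim, 14, 14): channels-first for interpolate

grid = F.interpolate(grid, size=(new_height, new_width), mode="bicubic", align_corners=False)

patch_pos_embed = grid.permute(0, 2, 3, 1).view(1, -1, dim)  # back to (1, 441, dim)
print(patch_pos_embed.shape)                 # torch.Size([1, 441, 768])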
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 8813d555968..45383a36d9b 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -36,6 +36,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_swin import SwinConfig
 
@@ -251,38 +252,49 @@ class SwinEmbeddings(nn.Module):
         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+        self.patch_size = config.patch_size
+        self.config = config
+
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
-        class_pos_embed = self.position_embeddings[:, 0]
+
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
+
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py
index b3e037c60b9..0c30e739a48 100644
--- a/src/transformers/models/swinv2/modeling_swinv2.py
+++ b/src/transformers/models/swinv2/modeling_swinv2.py
@@ -36,6 +36,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_swinv2 import Swinv2Config
 
@@ -293,38 +294,49 @@ class Swinv2Embeddings(nn.Module):
         self.norm = nn.LayerNorm(config.embed_dim)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
+        self.patch_size = config.patch_size
+        self.config = config
+
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
-        class_pos_embed = self.position_embeddings[:, 0]
+
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
+
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 7897555a6ba..76ebd18ed32 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -38,6 +38,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_vit import ViTConfig
 
@@ -70,40 +71,48 @@ class ViTEmbeddings(nn.Module):
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
         self.config = config
 
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
-        class_pos_embed = self.position_embeddings[:, 0]
+
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
+
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
-        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
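For callers nothing changes: the existing interpolate_pos_encoding flag keeps working in eager mode, and a checkpoint pre-trained at 224x224 can consume larger inputs directly. A typical call (the checkpoint name is only an example):

import torch
from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

pixel_values = torch.randn(1, 3, 480, 480)  # larger than the 224x224 pre-training resolution
with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)

# 480 / 16 = 30 patches per side -> 900 patch tokens + 1 class token
print(outputs.last_hidden_state.shape)  # torch.Size([1, 901, 768])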
diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py
index e85d996f47b..f6444999ac1 100755
--- a/src/transformers/models/vit_mae/modeling_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_vit_mae.py
@@ -35,6 +35,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_vit_mae import ViTMAEConfig
 
@@ -206,6 +207,7 @@ class ViTMAEEmbeddings(nn.Module):
         self.position_embeddings = nn.Parameter(
             torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
         )
+        self.patch_size = config.patch_size
         self.config = config
         self.initialize_weights()
 
@@ -223,40 +225,46 @@ class ViTMAEEmbeddings(nn.Module):
         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
         torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
+
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
 
-        class_pos_embed = self.position_embeddings[:, 0, :]
-        patch_pos_embed = self.position_embeddings[:, 1:, :]
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
            align_corners=False,
         )
-        if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
-            raise ValueError("Width or height does not match with the interpolated position embeddings")
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def random_masking(self, sequence, noise=None):
         """
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index c89370be5c0..b962ac597da 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -27,7 +27,13 @@ from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
 from .configuration_vit_msn import ViTMSNConfig
 
 
@@ -52,42 +58,49 @@ class ViTMSNEmbeddings(nn.Module):
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
         self.config = config
 
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
-        class_pos_embed = self.position_embeddings[:, 0]
+
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
+
         dim = embeddings.shape[-1]
-        patch_window_height = height // self.config.patch_size
-        patch_window_width = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(
-                patch_window_height / math.sqrt(num_positions),
-                patch_window_width / math.sqrt(num_positions),
-            ),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(
         self,
diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py
index d35c3eb7977..972040264fe 100755
--- a/src/transformers/models/vivit/modeling_vivit.py
+++ b/src/transformers/models/vivit/modeling_vivit.py
@@ -26,7 +26,13 @@ from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
 from .configuration_vivit import VivitConfig
 
@@ -100,37 +106,47 @@ class VivitEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.tubelet_size[2]
         self.config = config
 
-    def interpolate_pos_encoding(self, embeddings, height, width):
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
-        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
-        resolution images.
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
 
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
         """
         num_patches = embeddings.shape[1] - 1
         num_positions = self.position_embeddings.shape[1] - 1
-        if num_patches == num_positions and height == width:
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
             return self.position_embeddings
-        class_pos_embed = self.position_embeddings[:, 0]
+
+        class_pos_embed = self.position_embeddings[:, :1]
         patch_pos_embed = self.position_embeddings[:, 1:]
+
         dim = embeddings.shape[-1]
-        h0 = height // self.config.patch_size
-        w0 = width // self.config.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        h0, w0 = h0 + 0.1, w0 + 0.1
-        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
 
     def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
         batch_size, num_frames, num_channels, height, width = pixel_values.shape
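Replacing the fractional scale_factor (with its + 0.1 nudge) by an exact output size also changes the interpolation arithmetic slightly, which is consistent with the regenerated expected values in the Hiera integration test below. A small demonstration of the numeric drift between the two recipes (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
pretrained_grid = torch.randn(1, 8, 7, 7)

# old recipe: fractional scale factor, nudged by +0.1 to dodge rounding issues
old = F.interpolate(
    pretrained_grid, scale_factor=((14 + 0.1) / 7, (14 + 0.1) / 7), mode="bicubic", align_corners=False
)
# new recipe: ask for the target size directly
new = F.interpolate(pretrained_grid, size=(14, 14), mode="bicubic", align_corners=False)

print(old.shape == new.shape)    # True: floor(7 * 2.014...) still gives 14
print((old - new).abs().max())   # small but nonzero difference, enough to move test tolerances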
diff --git a/tests/models/hiera/test_modeling_hiera.py b/tests/models/hiera/test_modeling_hiera.py
index 4319e1eb0f4..b118d6db5af 100644
--- a/tests/models/hiera/test_modeling_hiera.py
+++ b/tests/models/hiera/test_modeling_hiera.py
@@ -578,7 +578,7 @@ class HieraModelIntegrationTest(unittest.TestCase):
         self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
 
         expected_slice = torch.tensor(
-            [[1.8522, 0.1532, 0.3849], [2.7352, -0.1941, 0.1848], [1.5859, -0.0773, 0.0168]]
+            [[1.7853, 0.0690, 0.3177], [2.6853, -0.2334, 0.0889], [1.5445, -0.1515, -0.0300]]
         ).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
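Taken together, this is the workflow the change set enables: trace once, then run at other resolutions. A hedged end-to-end sketch (same example checkpoint as above; the wrapper exists only to pin interpolate_pos_encoding=True and give the tracer a plain tensor output, and exact export behavior can still vary across torch versions):

import torch
from transformers import ViTModel


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values):
        return self.model(pixel_values, interpolate_pos_encoding=True).last_hidden_state


model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").eval()
traced = torch.jit.trace(TraceWrapper(model), torch.randn(1, 3, 224, 224))

# thanks to the tracing guard, the interpolation path is part of the graph,
# so a resolution never seen at trace time still works
print(traced(torch.randn(1, 3, 384, 384)).shape)  # expected: torch.Size([1, 577, 768])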