mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
🚨 Fix torch.jit.trace
for interpolate_pos_encoding
in all vision models (#33226)
* Fix `torch.jit.tracing` for `interpolate_pos_encoding` in all vision models * Apply formatting * Add missing `self.config = config` * Fix copies * Fix hiera interpolation unit test * Formatting * Update `_import_structure` * make style * Fix docstring * Use `# Copied from` instead of utils * DeiT variable renaming (`class_and_dist_pos_embed`) * Fix Hiera `interpolate_pos_encoding`
This commit is contained in:
parent
03164ba14e
commit
c6d2848a23
@ -41,6 +41,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_beit import BeitConfig
|
||||
@ -150,41 +151,46 @@ class BeitEmbeddings(nn.Module):
|
||||
self.position_embeddings = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
|
||||
higher resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h = height // self.patch_size
|
||||
w = width // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h, w = h + 0.1, w + 0.1
|
||||
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -566,7 +572,7 @@ class BeitRelativePositionBias(nn.Module):
|
||||
|
||||
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
|
||||
new_sub_table = nn.functional.interpolate(
|
||||
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
|
||||
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
|
||||
)
|
||||
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch BLIP model."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
@ -33,6 +32,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
|
||||
from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
|
||||
@ -232,38 +232,46 @@ class BlipVisionEmbeddings(nn.Module):
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embedding.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embedding
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
class_pos_embed = self.position_embedding[:, 0, :]
|
||||
patch_pos_embed = self.position_embedding[:, 1:, :]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
|
@ -38,6 +38,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
|
||||
@ -198,38 +199,46 @@ class Blip2VisionEmbeddings(nn.Module):
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embedding.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embedding
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
class_pos_embed = self.position_embedding[:, 0, :]
|
||||
patch_pos_embed = self.position_embedding[:, 1:, :]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
|
@ -39,6 +39,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_data2vec_vision import Data2VecVisionConfig
|
||||
|
||||
@ -149,41 +150,46 @@ class Data2VecVisionEmbeddings(nn.Module):
|
||||
self.position_embeddings = None
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
|
||||
higher resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
dim = embeddings.shape[-1]
|
||||
h = height // self.patch_size
|
||||
w = width // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h, w = h + 0.1, w + 0.1
|
||||
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -575,7 +581,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
||||
|
||||
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
|
||||
new_sub_table = nn.functional.interpolate(
|
||||
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
|
||||
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
|
||||
)
|
||||
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
|
||||
|
||||
|
@ -40,6 +40,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_deit import DeiTConfig
|
||||
|
||||
@ -77,39 +78,43 @@ class DeiTEmbeddings(nn.Module):
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing and 2 class embeddings.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
# return self.position_embeddings
|
||||
num_patches = embeddings.shape[1] - 2
|
||||
num_positions = self.position_embeddings.shape[1] - 2
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, 0, :]
|
||||
dist_pos_embed = self.position_embeddings[:, 1, :]
|
||||
patch_pos_embed = self.position_embeddings[:, 2:, :]
|
||||
class_and_dist_pos_embed = self.position_embeddings[:, :2]
|
||||
patch_pos_embed = self.position_embeddings[:, 2:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.patch_size
|
||||
w0 = width // self.patch_size
|
||||
# # we add a small number to avoid floating point error in the interpolation
|
||||
# # see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -27,7 +27,13 @@ from ....activations import ACT2FN
|
||||
from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
||||
from ....modeling_utils import PreTrainedModel
|
||||
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from ....utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
torch_int,
|
||||
)
|
||||
from ....utils.backbone_utils import load_backbone
|
||||
from .configuration_vit_hybrid import ViTHybridConfig
|
||||
|
||||
@ -60,41 +66,49 @@ class ViTHybridEmbeddings(nn.Module):
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
height = height // self.config.patch_size
|
||||
width = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
height, width = height + 0.1, width + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError(f"Invalid height or width: {height}, {width}")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -38,6 +38,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_dinov2 import Dinov2Config
|
||||
@ -71,42 +72,48 @@ class Dinov2Embeddings(nn.Module):
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
height = height // self.config.patch_size
|
||||
width = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
height, width = height + 0.1, width + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
target_dtype = patch_pos_embed.dtype
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed.to(dtype=torch.float32),
|
||||
scale_factor=(float(height / math.sqrt(num_positions)), float(width / math.sqrt(num_positions))),
|
||||
patch_pos_embed.to(torch.float32),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
).to(dtype=target_dtype)
|
||||
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError("Width or height does not match with the interpolated position embeddings")
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
|
@ -166,38 +166,49 @@ class DonutSwinEmbeddings(nn.Module):
|
||||
|
||||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -152,7 +152,7 @@ class DPTViTHybridEmbeddings(nn.Module):
|
||||
posemb_tok = posemb[:, :start_index]
|
||||
posemb_grid = posemb[0, start_index:]
|
||||
|
||||
old_grid_size = int(math.sqrt(len(posemb_grid)))
|
||||
old_grid_size = torch_int(len(posemb_grid) ** 0.5)
|
||||
|
||||
posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
|
||||
@ -626,7 +626,7 @@ class DPTReassembleStage(nn.Module):
|
||||
if patch_height is not None and patch_width is not None:
|
||||
hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
|
||||
else:
|
||||
size = int(math.sqrt(sequence_length))
|
||||
size = torch_int(sequence_length**0.5)
|
||||
hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
|
||||
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
|
@ -34,6 +34,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_flava import (
|
||||
FlavaConfig,
|
||||
@ -259,42 +260,49 @@ class FlavaImageEmbeddings(nn.Module):
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/image_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
npatch = embeddings.shape[1] - 1
|
||||
num_pos = self.position_embeddings.shape[1] - 1
|
||||
if npatch == num_pos and height == width:
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
num_h_patches = height // self.config.patch_size
|
||||
num_w_patches = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed.reshape(1, int(math.sqrt(num_pos)), int(math.sqrt(num_pos)), dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(num_h_patches / math.sqrt(num_pos), num_w_patches / math.sqrt(num_pos)),
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
|
||||
raise ValueError(
|
||||
f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
|
||||
f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -15,7 +15,6 @@
|
||||
"""PyTorch GroupViT model."""
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
@ -34,6 +33,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
|
||||
|
||||
@ -365,39 +365,44 @@ class GroupViTVisionEmbeddings(nn.Module):
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.patch_size = config.patch_size
|
||||
self.config = config
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing and no class embeddings.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
npatch = embeddings.shape[1]
|
||||
if npatch == self.position_embeddings.shape[1] and height == width:
|
||||
num_patches = embeddings.shape[1]
|
||||
num_positions = self.position_embeddings.shape[1]
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
patch_pos_embed = self.position_embeddings
|
||||
num_original_pos_embed = patch_pos_embed.shape[1]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
feat_height = height // self.config.patch_size
|
||||
feat_width = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
feat_height, feat_width = feat_height + 0.1, feat_width + 0.1
|
||||
original_height = original_width = math.sqrt(num_original_pos_embed)
|
||||
reshaped_patch_pos_embed = patch_pos_embed.reshape(1, int(original_height), int(original_width), dim).permute(
|
||||
0, 3, 1, 2
|
||||
)
|
||||
scale_factor = (feat_height / original_height, feat_width / original_width)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
reshaped_patch_pos_embed,
|
||||
scale_factor=scale_factor,
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return patch_pos_embed
|
||||
|
||||
|
@ -38,6 +38,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_hiera import HieraConfig
|
||||
@ -320,46 +321,48 @@ class HieraEmbeddings(nn.Module):
|
||||
self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing, no class embeddings, and different patch strides.
|
||||
|
||||
Adapted from:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1]
|
||||
num_positions = pos_embeds.shape[1]
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return pos_embeds
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.patch_stride[0]
|
||||
w0 = width // self.patch_stride[1]
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
pos_embeds = pos_embeds.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_stride[0]
|
||||
new_width = width // self.patch_stride[1]
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
pos_embeds = pos_embeds.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
pos_embeds = pos_embeds.permute(0, 3, 1, 2)
|
||||
|
||||
pos_embeds = nn.functional.interpolate(
|
||||
pos_embeds,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if int(h0) != pos_embeds.shape[-2] or int(w0) != pos_embeds.shape[-1]:
|
||||
raise ValueError("The interpolated position encoding does not have the right size")
|
||||
|
||||
pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return pos_embeds
|
||||
|
||||
def get_position_embedding(
|
||||
self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
|
||||
) -> torch.FloatTensor:
|
||||
position_embeddings = self.position_embeddings
|
||||
position_embeddings = (
|
||||
self.interpolate_pos_encoding(embeddings, position_embeddings, height, width)
|
||||
return (
|
||||
self.interpolate_pos_encoding(embeddings, self.position_embeddings, height, width)
|
||||
if interpolate_pos_encoding
|
||||
else position_embeddings
|
||||
else self.position_embeddings
|
||||
)
|
||||
return position_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -38,6 +38,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
|
||||
@ -102,38 +103,46 @@ class InstructBlipVisionEmbeddings(nn.Module):
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embedding.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embedding
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
class_pos_embed = self.position_embedding[:, 0, :]
|
||||
patch_pos_embed = self.position_embedding[:, 1:, :]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
|
@ -44,6 +44,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
from .configuration_instructblipvideo import (
|
||||
@ -110,38 +111,46 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embedding.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
|
||||
if num_patches == num_positions and height == width:
|
||||
return self.position_embedding
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
class_pos_embed = self.position_embedding[:, 0, :]
|
||||
patch_pos_embed = self.position_embedding[:, 1:, :]
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
|
@ -29,6 +29,7 @@ from ...file_utils import ModelOutput
|
||||
from ...modeling_outputs import BackboneOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
||||
from ...utils import torch_int
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_maskformer_swin import MaskFormerSwinConfig
|
||||
|
||||
@ -162,38 +163,48 @@ class MaskFormerSwinEmbeddings(nn.Module):
|
||||
|
||||
self.norm = nn.LayerNorm(config.embed_dim)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
num_positions = self.position_embeddings.shape[1] - 1
|
||||
if num_patches == num_positions and height == width:
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, 0]
|
||||
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
h0, w0 = h0 + 0.1, w0 + 0.1
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values, interpolate_pos_encoding):
|
||||
_, num_channels, height, width = pixel_values.shape
|
||||
|
@ -37,6 +37,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_perceiver import PerceiverConfig
|
||||
|
||||
@ -2767,13 +2768,19 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
|
||||
|
||||
def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
num_positions = position_embeddings.shape[0]
|
||||
new_height = new_width = math.sqrt(num_positions)
|
||||
position_embeddings = position_embeddings.reshape(
|
||||
1, int(new_height), int(new_width), self._num_channels
|
||||
).permute(0, 3, 1, 2)
|
||||
new_height = new_width = torch_int(num_positions**0.5)
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and height == new_height and width == new_width:
|
||||
return position_embeddings
|
||||
|
||||
position_embeddings = position_embeddings.reshape(1, new_height, new_width, self._num_channels).permute(
|
||||
0, 3, 1, 2
|
||||
)
|
||||
|
||||
position_embeddings = nn.functional.interpolate(
|
||||
position_embeddings,
|
||||
scale_factor=(height / new_height, width / new_width),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
@ -2787,7 +2794,6 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
height, width = input_size
|
||||
height, width = height + 0.1, width + 0.1
|
||||
position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)
|
||||
|
||||
if batch_size is not None:
|
||||
|
@ -123,7 +123,9 @@ class PvtPatchEmbeddings(nn.Module):

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
num_patches = height * width
if num_patches == self.config.image_size * self.config.image_size:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == self.config.image_size * self.config.image_size:
return self.position_embeddings
embeddings = embeddings.reshape(1, height, width, -1).permute(0, 3, 1, 2)
interpolated_embeddings = F.interpolate(embeddings, size=(height, width), mode="bilinear")
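The `not torch.jit.is_tracing()` guard matters because torch.jit.trace records only the branch taken with the example input, so a shape-dependent Python `if` gets frozen into the graph. A toy reproduction of the failure mode the guard avoids (the module and sizes are made up):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # a shape-dependent Python branch: tracing bakes in whichever path the example input takes
        if x.shape[-1] == 4:
            return x + 1
        return x - 1

traced = torch.jit.trace(Toy(), torch.zeros(1, 4))
print(traced(torch.zeros(1, 8)))   # still x + 1: the other branch never made it into the trace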
@ -15,7 +15,6 @@
"""PyTorch SegGpt model."""

import collections.abc
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

@ -32,6 +31,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_seggpt import SegGptConfig

@ -155,9 +155,10 @@ class SegGptEmbeddings(nn.Module):
def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor:
patch_pos_embed = self.position_embeddings[:, 1:]
num_patches = patch_pos_embed.shape[1]
pretrain_patch_size = int(math.sqrt(num_patches))
pretrain_patch_size = torch_int(num_patches**0.5)

if pretrain_patch_size != height or pretrain_patch_size != width:
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width:
patch_pos_embed = F.interpolate(
patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2),
size=(height, width),
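SegGpt flips the condition (`torch.jit.is_tracing() or ...`) because its interpolation sits inside the `if` rather than behind an early return; the effect is the same, interpolation always runs while tracing. The newly imported `torch_int` helper is what keeps `num_patches**0.5` usable in both worlds; roughly, and only as a sketch of the idea rather than the library source, it behaves like:

import torch

def torch_int_sketch(x):
    # while tracing, keep the value as an int64 tensor so the computation stays in the graph;
    # in eager mode, fall back to a plain Python int (e.g. int(196**0.5) == 14)
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)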
@ -38,6 +38,7 @@ from ...utils import (
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

@ -269,38 +270,38 @@ class SiglipVisionEmbeddings(nn.Module):

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs)
that allows the model to interpolate the pre-trained position encodings such that it can be usable on
higher resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing and no class embeddings.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)

num_patches = embeddings.shape[1]
num_positions = position_embeddings.shape[1]
if num_patches == num_positions and height == width:
return position_embeddings
num_positions = self.position_embeddings.shape[1]

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

patch_pos_embed = self.position_embeddings

dim = embeddings.shape[-1]
height = height // self.patch_size
width = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1

patch_pos_embed = position_embeddings.reshape(
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
)
new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
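Because SigLIP has no class token, its whole position table is one square grid and is resampled in a single pass, with no split and concat around a class embedding. A toy version of that case (layer sizes below are illustrative, not SigLIP's real ones):

import torch
from torch import nn

embed_dim, old_grid, new_grid = 32, 16, 24
position_embedding = nn.Embedding(old_grid * old_grid, embed_dim)

pos = position_embedding.weight.unsqueeze(0)                            # (1, 256, 32)
pos = pos.reshape(1, old_grid, old_grid, embed_dim).permute(0, 3, 1, 2)
pos = nn.functional.interpolate(pos, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
pos = pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, embed_dim)
print(pos.shape)                                                         # torch.Size([1, 576, 32])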
@ -251,38 +251,49 @@ class SwinEmbeddings(nn.Module):

self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]

class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
@ -36,6 +36,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_swinv2 import Swinv2Config
@ -293,38 +294,49 @@ class Swinv2Embeddings(nn.Module):

self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]

class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
@ -38,6 +38,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_vit import ViTConfig

@ -70,40 +71,48 @@ class ViTEmbeddings(nn.Module):
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]

class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
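With the unconditional interpolation under tracing, a ViT traced at one resolution can be reused at another, which is the scenario the PR title refers to. A hedged usage sketch (the checkpoint name and image sizes are illustrative; the small wrapper pins the boolean flag and returns a plain tensor so that only tensors cross the tracing boundary):

import torch
from transformers import ViTModel

class TraceableViT(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values):
        return self.model(pixel_values=pixel_values, interpolate_pos_encoding=True).last_hidden_state

vit = ViTModel.from_pretrained("google/vit-base-patch16-224").eval()
traced = torch.jit.trace(TraceableViT(vit), torch.randn(1, 3, 224, 224))

# the traced graph now also accepts a different resolution, since interpolation always runs while tracing
out = traced(torch.randn(1, 3, 384, 384))
print(out.shape)   # expected (1, 1 + (384 // 16) ** 2, 768) == (1, 577, 768)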
@ -35,6 +35,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_vit_mae import ViTMAEConfig

@ -206,6 +207,7 @@ class ViTMAEEmbeddings(nn.Module):
self.position_embeddings = nn.Parameter(
torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
)
self.patch_size = config.patch_size
self.config = config
self.initialize_weights()

@ -223,40 +225,46 @@ class ViTMAEEmbeddings(nn.Module):
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1

if num_patches == num_positions and height == width:
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0, :]
patch_pos_embed = self.position_embeddings[:, 1:, :]
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def random_masking(self, sequence, noise=None):
"""
@ -27,7 +27,13 @@ from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_vit_msn import ViTMSNConfig

@ -52,42 +58,49 @@ class ViTMSNEmbeddings(nn.Module):
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.patch_size
self.config = config

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]

class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
patch_window_height = height // self.config.patch_size
patch_window_width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
patch_window_height, patch_window_width = patch_window_height + 0.1, patch_window_width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(
patch_window_height / math.sqrt(num_positions),
patch_window_width / math.sqrt(num_positions),
),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
@ -26,7 +26,13 @@ from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_vivit import VivitConfig

@ -100,37 +106,46 @@ class VivitEmbeddings(nn.Module):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config

def interpolate_pos_encoding(self, embeddings, height, width):
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0]

class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
batch_size, num_frames, num_channels, height, width = pixel_values.shape
@ -578,7 +578,7 @@ class HieraModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

expected_slice = torch.tensor(
[[1.8522, 0.1532, 0.3849], [2.7352, -0.1941, 0.1848], [1.5859, -0.0773, 0.0168]]
[[1.7853, 0.0690, 0.3177], [2.6853, -0.2334, 0.0889], [1.5445, -0.1515, -0.0300]]
).to(torch_device)

self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
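The Hiera integration test pins new reference values because the interpolation fix slightly changes the resampled positions, and with them the model output. When such a slice needs refreshing, the usual approach is to rerun the model and print the asserted slice; a sketch of that procedure (the checkpoint name is illustrative, and the real test feeds a preprocessed reference image rather than zeros, so only that input reproduces the numbers above):

import torch
from transformers import HieraModel

model = HieraModel.from_pretrained("facebook/hiera-tiny-224-hf").eval()  # illustrative checkpoint

with torch.no_grad():
    outputs = model(pixel_values=torch.zeros(1, 3, 224, 224))

# paste the printed values into `expected_slice`, keeping the atol=1e-4 tolerance of the assertion
print(outputs.last_hidden_state[0, :3, :3])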