Add common test for torch.export and fix some vision models (#35124)

* Add is_torch_greater_or_equal test decorator

* Add common test for torch.export

* Fix bit

* Fix focalnet

* Fix imagegpt

* Fix seggpt

* Fix swin2sr

* Enable torch.export test for vision models

* Enable test for video models

* Remove json

* Enable for hiera

* Enable for ijepa

* Fix detr

* Fix conditional_detr

* Fix maskformer

* Enable test maskformer

* Fix test for deformable detr

* Fix custom kernels for export in rt-detr and deformable-detr

* Enable test for all DPT

* Remove custom test for deformable detr

* Simplify test to use only kwargs for export

* Add comment

* Move compile_compatible_method_lru_cache to utils

* Fix beit export

* Fix deformable detr

* Fix copies data2vec<->beit

* Fix typos, update test to work with dict

* Add seed to the test

* Enable test for vit_mae

* Fix beit tests

* [run-slow] beit, bit, conditional_detr, data2vec, deformable_detr, detr, focalnet, imagegpt, maskformer, rt_detr, seggpt, swin2sr

* Add vitpose test

* Add textnet test

* Add dinov2 with registers

* Update tests/test_modeling_common.py

* Switch to torch.testing.assert_close

* Fix maskformer

* Remove save-load from test

* Add dab_detr

* Add depth_pro

* Fix and test RT-DETRv2

* Fix dab_detr
Pavel Iakubovskii 2025-02-11 11:37:31 +00:00 committed by GitHub
parent 1779f5180e
commit f42d46ccb4
77 changed files with 305 additions and 151 deletions
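
For context, the common test introduced in this commit exports each opted-in model with torch.export.export (keyword arguments only, strict mode) and checks that the exported program reproduces the eager outputs. A minimal standalone sketch of that flow, assuming PyTorch >= 2.5; "microsoft/resnet-50" is only an illustrative checkpoint, not one the test suite uses:

```python
import torch
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50").eval()
inputs = {"pixel_values": torch.randn(1, 3, 224, 224)}

# Export with kwargs only, mirroring the common test added in test_modeling_common.py below
exported = torch.export.export(model, args=(), kwargs=inputs, strict=True)

with torch.no_grad():
    eager_logits = model(**inputs)["logits"]
    exported_logits = exported.module()(**inputs)["logits"]

# The common test compares every tensor output with torch.testing.assert_close
torch.testing.assert_close(eager_logits, exported_logits, atol=1e-4, rtol=1e-4)
```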

View File

@ -34,7 +34,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -297,10 +297,9 @@ class BeitSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
if window_size:
self.has_relative_position_bias = bool(window_size)
if self.has_relative_position_bias:
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
else:
self.relative_position_bias = None
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@ -312,7 +311,7 @@ class BeitSelfAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -328,7 +327,7 @@ class BeitSelfAttention(nn.Module):
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Add relative position bias if present.
if self.relative_position_bias is not None:
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attention_scores = attention_scores + self.relative_position_bias(
@ -367,7 +366,7 @@ class BeitSdpaSelfAttention(BeitSelfAttention):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -393,7 +392,7 @@ class BeitSdpaSelfAttention(BeitSelfAttention):
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_bias = None
if self.relative_position_bias is not None:
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attn_bias = self.relative_position_bias(
@ -477,7 +476,7 @@ class BeitAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -546,7 +545,7 @@ class BeitLayer(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -595,8 +594,7 @@ class BeitRelativePositionBias(nn.Module):
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
self.relative_position_indices = {}
@compile_compatible_method_lru_cache(maxsize=10)
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
"""
This method creates the relative position index, modified to support arbitrary window sizes,
@ -648,11 +646,9 @@ class BeitRelativePositionBias(nn.Module):
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
)
key = window_size
if key not in self.relative_position_indices.keys():
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
relative_position_index = self.generate_relative_position_index(window_size)
relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
relative_position_bias = relative_position_bias.view(
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
@ -675,10 +671,9 @@ class BeitEncoder(nn.Module):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.config = config
if config.use_shared_relative_position_bias:
self.has_relative_position_bias = config.use_shared_relative_position_bias
if self.has_relative_position_bias:
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
else:
self.relative_position_bias = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
@ -701,7 +696,7 @@ class BeitEncoder(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
resolution: Optional[Tuple[int, int]] = None,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@ -711,6 +706,15 @@ class BeitEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
else:
relative_position_bias = None
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
@ -719,17 +723,11 @@ class BeitEncoder(nn.Module):
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
else:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = (
self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states,
layer_head_mask,

View File

@ -192,7 +192,7 @@ class DynamicPad2d(nn.Module):
self.compute_padding = compute_padding
def __call__(self, input):
def forward(self, input):
# Get width and height
input_height, input_width = input.size()[-2:]

View File

@ -1735,7 +1735,11 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
# class logits + predicted bounding boxes
logits = self.class_labels_classifier(sequence_output)
reference = outputs.reference_points if return_dict else outputs[-1]
# Index [-2] is valid only if `output_attentions` and `output_hidden_states`
# are not specified, otherwise it will be another index which is hard to determine.
# Leave it as is, because it's not a common case to use
# return_dict=False + output_attentions=True / output_hidden_states=True
reference = outputs.reference_points if return_dict else outputs[-2]
reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
hs = sequence_output
@ -2105,7 +2109,7 @@ class ConditionalDetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
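
The change above (the same one is applied to DAB-DETR and DETR further down) replaces the in-place `masked_fill_` with the out-of-place `masked_fill`. The values are identical; the out-of-place form simply avoids mutating `weights` in place, which is presumably friendlier to torch.export tracing. A small self-contained check with made-up shapes:

```python
import torch

weights = torch.randn(2, 3, 4, 5, 5)   # (batch, queries, heads, height, width)
mask = torch.rand(2, 5, 5) > 0.5       # padding mask
fill_value = torch.finfo(weights.dtype).min

# Out-of-place version used in the diff
out_of_place = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), fill_value)

# Previous in-place version (applied to a copy so the two can be compared)
in_place = weights.clone()
in_place.masked_fill_(mask.unsqueeze(1).unsqueeze(1), fill_value)

assert torch.equal(out_of_place, in_place)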

View File

@ -1537,7 +1537,7 @@ class DabDetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights

View File

@ -32,7 +32,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -298,10 +298,9 @@ class Data2VecVisionSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
if window_size:
self.has_relative_position_bias = bool(window_size)
if self.has_relative_position_bias:
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
else:
self.relative_position_bias = None
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@ -313,7 +312,7 @@ class Data2VecVisionSelfAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -329,7 +328,7 @@ class Data2VecVisionSelfAttention(nn.Module):
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Add relative position bias if present.
if self.relative_position_bias is not None:
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attention_scores = attention_scores + self.relative_position_bias(
@ -369,7 +368,7 @@ class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -395,7 +394,7 @@ class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
query_layer = self.transpose_for_scores(mixed_query_layer)
attn_bias = None
if self.relative_position_bias is not None:
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attn_bias = self.relative_position_bias(
@ -557,7 +556,7 @@ class Data2VecVisionLayer(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
@ -607,8 +606,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
self.relative_position_indices = {}
@compile_compatible_method_lru_cache(maxsize=10)
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
"""
This method creates the relative position index, modified to support arbitrary window sizes,
@ -660,11 +658,9 @@ class Data2VecVisionRelativePositionBias(nn.Module):
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
)
key = window_size
if key not in self.relative_position_indices.keys():
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
relative_position_index = self.generate_relative_position_index(window_size)
relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
relative_position_bias = relative_position_bias.view(
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
@ -688,10 +684,9 @@ class Data2VecVisionEncoder(nn.Module):
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.config = config
if config.use_shared_relative_position_bias:
self.has_relative_position_bias = config.use_shared_relative_position_bias
if self.has_relative_position_bias:
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
else:
self.relative_position_bias = None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
@ -714,7 +709,7 @@ class Data2VecVisionEncoder(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
resolution: Optional[Tuple[int, int]] = None,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@ -724,6 +719,15 @@ class Data2VecVisionEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
else:
relative_position_bias = None
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
@ -732,17 +736,11 @@ class Data2VecVisionEncoder(nn.Module):
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
else:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = (
self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states,
layer_head_mask,

View File

@ -40,6 +40,7 @@ from ...utils import (
is_ninja_available,
is_timm_available,
is_torch_cuda_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
requires_backends,
@ -705,7 +706,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
@ -1606,7 +1607,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
Args:
enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
spatial_shapes (Tensor[num_feature_levels, 2]): Spatial shapes of the feature maps.
spatial_shapes (List[Tuple[int, int]]): Spatial shapes of the feature maps.
Returns:
`tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
@ -1786,7 +1787,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
enc_outputs_coord_logits = None
if self.config.two_stage:
object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
encoder_outputs[0], ~mask_flatten, spatial_shapes
encoder_outputs[0], ~mask_flatten, spatial_shapes_list
)
# hack implementation for two-stage Deformable DETR

View File

@ -1801,7 +1801,7 @@ class DetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)
if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights

View File

@ -358,23 +358,23 @@ class FocalNetModulation(nn.Module):
# pre linear projection
x = self.projection_in(hidden_state).permute(0, 3, 1, 2).contiguous()
q, ctx, self.gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)
q, ctx, gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)
# context aggregation
ctx_all = 0
for level in range(self.focal_level):
ctx = self.focal_layers[level](ctx)
ctx_all = ctx_all + ctx * self.gates[:, level : level + 1]
ctx_all = ctx_all + ctx * gates[:, level : level + 1]
ctx_global = self.activation(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level :]
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level :]
# normalize context
if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1)
# focal modulation
self.modulator = self.projection_context(ctx_all)
x_out = q * self.modulator
modulator = self.projection_context(ctx_all)
x_out = q * modulator
x_out = x_out.permute(0, 2, 3, 1).contiguous()
if self.use_post_layernorm_in_modulation:
x_out = self.layernorm(x_out)

View File

@ -164,13 +164,11 @@ class ImageGPTLayerNorm(nn.Module):
self.eps = eps
self.weight = nn.Parameter(torch.Tensor(hidden_size))
def forward(self, tensor: torch.Tensor) -> tuple:
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
# input is not mean centered
return (
tensor
/ torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
* self.weight.data[..., :]
)
tensor = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
tensor = tensor * self.weight
return tensor
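
A quick, hedged sanity check (not part of the diff) that the rewritten ImageGPTLayerNorm forward above, which multiplies by self.weight directly and so keeps the weight in the autograd graph, matches the old expression that went through self.weight.data:

```python
import torch
from torch import nn

hidden_size, eps = 8, 1e-5
weight = nn.Parameter(torch.randn(hidden_size))
tensor = torch.randn(2, 4, hidden_size)

old = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + eps) * weight.data[..., :]
new = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + eps) * weight

torch.testing.assert_close(old, new)
```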
class ImageGPTAttention(nn.Module):

View File

@ -1424,7 +1424,11 @@ class MaskFormerTransformerModule(nn.Module):
# repeat the queries "q c -> b q c"
batch_size = image_features.shape[0]
queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=self.training)
# torch.export.export does not support requires_grad
if self.training:
inputs_embeds.requires_grad_(True)
batch_size, num_channels, height, width = image_features.shape
# rearrange both image_features and object_queries "b c h w -> b (h w) c"

View File

@ -18,7 +18,7 @@ import math
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
@ -32,12 +32,14 @@ from ...activations import ACT2CLS, ACT2FN
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_torch_cuda_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -870,7 +872,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights
@ -1590,27 +1592,6 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
)
def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not torch.compiler.is_compiling():
# Cache the function only if the model is not being compiled
# check if the function is already cached, otherwise create it
if not hasattr(self, f"_cached_{func.__name__}"):
self.__setattr__(
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
)
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
else:
# Otherwise, just call the original function
return func(self, *args, **kwargs)
return wrapper
return decorator
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
class RTDetrMLPPredictionHead(nn.Module):
"""
@ -1728,7 +1709,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
for param in self.backbone.parameters():
param.requires_grad_(True)
@compile_compatible_lru_cache(maxsize=32)
@compile_compatible_method_lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
if spatial_shapes is None:
spatial_shapes = [

View File

@ -22,7 +22,7 @@ import math
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
@ -34,12 +34,14 @@ from ...activations import ACT2CLS, ACT2FN
from ...image_transforms import center_to_corners_format, corners_to_center_format
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import compile_compatible_method_lru_cache
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_torch_cuda_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@ -97,7 +99,7 @@ def multi_scale_deformable_attention_v2(
value_list = (
value.permute(0, 2, 3, 1)
.flatten(0, 1)
.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1)
.split([height * width for height, width in value_spatial_shapes], dim=-1)
)
# sampling_offsets [8, 480, 8, 12, 2]
if method == "default":
@ -226,9 +228,9 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
**kwargs,
):
# Process inputs up to sampling locations calculation using parent class logic
if position_embeddings is not None:
@ -236,7 +238,7 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@ -272,7 +274,7 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
# V2-specific attention implementation choice
output = multi_scale_deformable_attention_v2(
value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list, self.method
value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
)
output = self.output_proj(output)
@ -1329,27 +1331,6 @@ RTDetrV2_INPUTS_DOCSTRING = r"""
"""
def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not torch.compiler.is_compiling():
# Cache the function only if the model is not being compiled
# check if the function is already cached, otherwise create it
if not hasattr(self, f"_cached_{func.__name__}"):
self.__setattr__(
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
)
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
else:
# Otherwise, just call the original function
return func(self, *args, **kwargs)
return wrapper
return decorator
def _get_clones(partial_module, N):
return nn.ModuleList([partial_module() for i in range(N)])
@ -1669,7 +1650,7 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
for param in self.backbone.parameters():
param.requires_grad_(True)
@compile_compatible_lru_cache(maxsize=32)
@compile_compatible_method_lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
if spatial_shapes is None:
spatial_shapes = [

View File

@ -20,7 +20,7 @@ import torch.nn.functional as F
from torch import Tensor, nn
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils import is_torchdynamo_compiling, logging
from ...utils.backbone_utils import (
verify_backbone_config_arguments,
)
@ -404,7 +404,7 @@ def multi_scale_deformable_attention_v2(
value_list = (
value.permute(0, 2, 3, 1)
.flatten(0, 1)
.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1)
.split([height * width for height, width in value_spatial_shapes], dim=-1)
)
# sampling_offsets [8, 480, 8, 12, 2]
if method == "default":
@ -497,9 +497,9 @@ class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention)
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
**kwargs,
):
# Process inputs up to sampling locations calculation using parent class logic
if position_embeddings is not None:
@ -507,7 +507,7 @@ class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention)
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@ -543,7 +543,7 @@ class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention)
# V2-specific attention implementation choice
output = multi_scale_deformable_attention_v2(
value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list, self.method
value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
)
output = self.output_proj(output)

View File

@ -817,8 +817,11 @@ class SegGptModel(SegGptPreTrainedModel):
# and reconstructed together (In-Context Painting).
if bool_masked_pos is None:
num_patches = self.embeddings.patch_embeddings.num_patches
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
bool_masked_pos[num_patches // 2 :] = 1
bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
bool_masked_pos_ones = torch.ones(
num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
)
bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
bool_masked_pos = bool_masked_pos.unsqueeze(0)
embedding_output = self.embeddings(
@ -975,8 +978,11 @@ class SegGptForImageSegmentation(SegGptPreTrainedModel):
if bool_masked_pos is None:
num_patches = self.model.embeddings.patch_embeddings.num_patches
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
bool_masked_pos[num_patches // 2 :] = 1
bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
bool_masked_pos_ones = torch.ones(
num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
)
bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
bool_masked_pos = bool_masked_pos.unsqueeze(0)
outputs = self.model(
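
The SegGPT change above (applied in both SegGptModel and SegGptForImageSegmentation) builds the default bool_masked_pos with torch.cat instead of in-place slice assignment, presumably to keep the traced graph free of in-place mutation. A quick self-contained check that the two constructions agree, using a made-up patch count:

```python
import torch

num_patches = 7

# Previous construction: in-place slice assignment
old = torch.zeros(num_patches, dtype=torch.bool)
old[num_patches // 2 :] = 1

# New, export-friendly construction from the diff
new = torch.cat(
    [
        torch.zeros(num_patches // 2, dtype=torch.bool),
        torch.ones(num_patches - num_patches // 2, dtype=torch.bool),
    ]
)

assert torch.equal(old, new)
```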

View File

@ -813,10 +813,11 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
self.config = config
if config.num_channels == 3 and config.num_channels_out == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
mean = torch.zeros(1, 1, 1, 1)
self.register_buffer("mean", mean, persistent=False)
self.img_range = config.img_range
self.first_convolution = nn.Conv2d(config.num_channels, config.embed_dim, 3, 1, 1)
@ -851,8 +852,8 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
pixel_values = nn.functional.pad(pixel_values, (0, modulo_pad_width, 0, modulo_pad_height), "reflect")
# 2. normalize
self.mean = self.mean.type_as(pixel_values)
pixel_values = (pixel_values - self.mean) * self.img_range
mean = self.mean.type_as(pixel_values)
pixel_values = (pixel_values - mean) * self.img_range
return pixel_values
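
The Swin2SR change above registers `mean` as a non-persistent buffer and stops reassigning self.mean inside forward. A hedged sketch of the same pattern in isolation (the Normalize module is hypothetical): the buffer follows .to() / .half() like a parameter, persistent=False keeps it out of the state dict, and the forward pass no longer mutates module state:

```python
import torch
from torch import nn


class Normalize(nn.Module):
    def __init__(self):
        super().__init__()
        mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
        self.register_buffer("mean", mean, persistent=False)  # moves with the module, not saved

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # type_as handles a dtype mismatch without reassigning self.mean
        return pixel_values - self.mean.type_as(pixel_values)


module = Normalize().half()
print(module.mean.dtype)           # torch.float16: the buffer followed .half()
print(list(module.state_dict()))   # []: persistent=False keeps it out of the state dict
```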

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import inspect
from functools import lru_cache, wraps
from typing import Callable, List, Optional, Set, Tuple, Union
import torch
@ -21,7 +22,7 @@ from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from .utils import is_torch_greater_or_equal, is_torch_xla_available, logging
from .utils import is_torch_greater_or_equal, is_torch_xla_available, is_torchdynamo_compiling, logging
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@ -364,3 +365,29 @@ def translate_to_torch_parallel_style(style: str):
return RowwiseParallel(input_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
"""
LRU cache decorator from standard functools library, but with a workaround to disable
caching when torchdynamo is compiling. Expected to work with class methods.
"""
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not is_torchdynamo_compiling():
# Cache the function only if the model is not being compiled
# check if the function is already cached, otherwise create it
if not hasattr(self, f"_cached_{func.__name__}"):
self.__setattr__(
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
)
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
else:
# Otherwise, just call the original function
return func(self, *args, **kwargs)
return wrapper
return decorator
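
A minimal usage sketch for the helper above (the Grid module is hypothetical, not from this commit): in eager mode the decorated method is cached per instance, while under torch.compile / torch.export the wrapper falls through to a plain call so tracing sees the real computation:

```python
import torch
from torch import nn

from transformers.pytorch_utils import compile_compatible_method_lru_cache


class Grid(nn.Module):
    @compile_compatible_method_lru_cache(maxsize=8)
    def make_coords(self, height: int, width: int) -> torch.Tensor:
        # Recomputed only on a cache miss when running eagerly
        y, x = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
        return torch.stack([x, y], dim=-1).float()

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        height, width = feature_map.shape[-2:]
        coords = self.make_coords(height, width)  # cached across calls with the same size
        return feature_map + coords.sum(-1)


out = Grid()(torch.zeros(1, 1, 4, 4))
```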

View File

@ -135,6 +135,7 @@ from .utils import (
is_torch_bf16_gpu_available,
is_torch_deterministic,
is_torch_fp16_available_on_device,
is_torch_greater_or_equal,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,
@ -556,6 +557,21 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_torch_greater_or_equal(version: str):
"""
Decorator marking a test that requires PyTorch version >= `version`.
These tests are skipped when PyTorch version is less than `version`.
"""
def decorator(test_case):
return unittest.skipUnless(is_torch_greater_or_equal(version), f"test requires PyTorch version >= {version}")(
test_case
)
return decorator
def require_flash_attn(test_case):
"""
Decorator marking a test that requires Flash Attention.
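
A small usage sketch for require_torch_greater_or_equal added above (the test class is hypothetical): it skips the test unless the installed PyTorch is at least the requested version, which is how the torch.export common test below is gated on 2.5:

```python
import unittest

import torch

from transformers.testing_utils import require_torch, require_torch_greater_or_equal


@require_torch
class ExportSmokeTest(unittest.TestCase):
    @require_torch_greater_or_equal("2.5")
    def test_linear_exports(self):
        module = torch.nn.Linear(2, 2).eval()
        example = torch.ones(1, 2)
        exported = torch.export.export(module, args=(example,))
        torch.testing.assert_close(exported.module()(example), module(example))
```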

View File

@ -271,6 +271,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = BeitModelTester(self)
@ -292,6 +293,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="BEiT can't compile dynamic")
def test_sdpa_can_compile_dynamic(self):
pass
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@ -764,13 +769,6 @@ class BeitModelIntegrationTest(unittest.TestCase):
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
pixel_values = inputs.pixel_values.to(torch_device)
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
# images than what the model supports.
self.assertFalse(processor.do_center_crop)
with torch.no_grad():
with self.assertRaises(ValueError, msg="doesn't match model"):
model(pixel_values, interpolate_pos_encoding=False)
# with interpolate_pos_encoding being True the model should process the higher resolution image
# successfully and produce the expected output.
with torch.no_grad():

View File

@ -170,6 +170,7 @@ class BitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = BitModelTester(self)

View File

@ -194,6 +194,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
test_head_masking = False
test_missing_keys = False
zero_init_hidden_state = True
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -180,6 +180,7 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ConvNextModelTester(self)

View File

@ -188,6 +188,7 @@ class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ConvNextV2ModelTester(self)

View File

@ -159,6 +159,7 @@ class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = CvtModelTester(self)

View File

@ -197,6 +197,7 @@ class DabDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
test_head_masking = False
test_missing_keys = False
zero_init_hidden_state = True
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -200,6 +200,7 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -222,6 +222,7 @@ class DeiTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DeiTModelTester(self)

View File

@ -146,6 +146,7 @@ class DepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DepthAnythingModelTester(self)

View File

@ -212,6 +212,7 @@ class DepthProModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DepthProModelTester(self)

View File

@ -194,6 +194,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
test_head_masking = False
test_missing_keys = False
zero_init_hidden_state = True
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -216,6 +216,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DinatModelTester(self)

View File

@ -212,6 +212,8 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
test_torch_exportable = True
all_model_classes = (
(
Dinov2Model,

View File

@ -237,6 +237,7 @@ class Dinov2WithRegistersModelTest(ModelTesterMixin, PipelineTesterMixin, unitte
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = Dinov2WithRegistersModelTester(self)

View File

@ -172,6 +172,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DPTModelTester(self)

View File

@ -140,6 +140,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DPTModelTester(self)

View File

@ -186,6 +186,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = DPTModelTester(self)

View File

@ -139,6 +139,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = EfficientNetModelTester(self)

View File

@ -247,6 +247,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = FocalNetModelTester(self)

View File

@ -152,6 +152,7 @@ class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_torch_exportable = True
def setUp(self):
self.model_tester = GLPNModelTester(self)

View File

@ -250,6 +250,7 @@ class HieraModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = HieraModelTester(self)

View File

@ -207,6 +207,7 @@ class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = IJepaModelTester(self)

View File

@ -237,6 +237,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
else {}
)
test_missing_keys = False
test_torch_exportable = True
# as ImageGPTForImageClassification isn't included in any auto mapping, we add labels here
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -205,6 +205,7 @@ class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torch_exportable = True
def setUp(self):
self.model_tester = Mask2FormerModelTester(self)

View File

@ -209,6 +209,7 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
test_head_masking = False
test_missing_keys = False
zero_init_hidden_state = True
test_torch_exportable = True
def setUp(self):
self.model_tester = MaskFormerModelTester(self)

View File

@ -181,6 +181,7 @@ class MaskFormerSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = MaskFormerSwinModelTester(self)

View File

@ -154,6 +154,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = MobileNetV1ModelTester(self)

View File

@ -205,6 +205,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = MobileNetV2ModelTester(self)

View File

@ -198,6 +198,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = MobileViTModelTester(self)

View File

@ -200,6 +200,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = MobileViTV2ModelTester(self)

View File

@ -132,6 +132,7 @@ class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
test_resize_embeddings = False
test_torchscript = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = PoolFormerModelTester(self)

View File

@ -166,6 +166,7 @@ class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_torchscript = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = PvtModelTester(self)

View File

@ -202,6 +202,7 @@ class PvtV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_torchscript = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = PvtV2ModelTester(self)

View File

@ -133,6 +133,7 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = RegNetModelTester(self)

View File

@ -178,6 +178,7 @@ class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ResNetModelTester(self)

View File

@ -261,6 +261,7 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -259,6 +259,7 @@ class RTDetrV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_pruning = False
test_head_masking = False
test_missing_keys = False
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -180,6 +180,7 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_torch_exportable = True
def setUp(self):
self.model_tester = SegformerModelTester(self)

View File

@ -172,6 +172,8 @@ class SegGptModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
test_torch_exportable = True
pipeline_model_mapping = (
{"feature-extraction": SegGptModel, "mask-generation": SegGptModel} if is_torch_available() else {}
)

View File

@ -147,6 +147,7 @@ class SwiftFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = SwiftFormerModelTester(self)

View File

@ -240,6 +240,7 @@ class SwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = SwinModelTester(self)

View File

@ -172,6 +172,7 @@ class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
test_torch_exportable = True
def setUp(self):
self.model_tester = Swin2SRModelTester(self)

View File

@ -226,6 +226,7 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = Swinv2ModelTester(self)

View File

@ -209,6 +209,7 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
test_head_masking = False
test_missing_keys = False
zero_init_hidden_state = True
test_torch_exportable = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -217,6 +217,7 @@ class TextNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
has_attentions = False
def setUp(self):

View File

@ -167,6 +167,7 @@ class TimesformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = TimesformerModelTester(self)

View File

@ -157,6 +157,7 @@ class UperNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_head_masking = False
test_torchscript = False
has_attentions = False
test_torch_exportable = True
def setUp(self):
self.model_tester = UperNetModelTester(self)

View File

@ -186,6 +186,7 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = VideoMAEModelTester(self)

View File

@ -207,6 +207,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ViTModelTester(self)

View File

@ -174,6 +174,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ViTMAEModelTester(self)

View File

@ -162,6 +162,7 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ViTMSNModelTester(self)

View File

@ -169,6 +169,7 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = VitDetModelTester(self)

View File

@ -143,6 +143,7 @@ class VitMatteModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = VitMatteModelTester(self)

View File

@ -154,6 +154,7 @@ class VitPoseModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = VitPoseModelTester(self)

View File

@ -18,7 +18,7 @@ import inspect
import unittest
from transformers import VitPoseBackboneConfig
from transformers.testing_utils import require_torch
from transformers.testing_utils import require_torch, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
@ -27,6 +27,8 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import VitPoseBackbone
@ -129,6 +131,7 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = VitPoseBackboneModelTester(self)
@ -187,6 +190,17 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_torch_export(self):
# Dense architecture
super().test_torch_export()
# MOE architecture
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_experts = 2
config.part_features = config.hidden_size // config.num_experts
inputs_dict["dataset_index"] = torch.tensor([0] * self.model_tester.batch_size, device=torch_device)
super().test_torch_export(config=config, inputs_dict=inputs_dict)
@require_torch
class VitPoseBackboneTest(unittest.TestCase, BackboneTesterMixin):

View File

@ -175,6 +175,7 @@ class VivitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = VivitModelTester(self)

View File

@ -178,6 +178,7 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_resize_embeddings = False
test_head_masking = False
test_torchscript = False
test_torch_exportable = True
# special case for head model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -147,6 +147,7 @@ class ZoeDepthModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ZoeDepthModelTester(self)

View File

@ -86,6 +86,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_greater_or_equal,
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
@ -221,6 +222,7 @@ class ModelTesterMixin:
test_mismatched_shapes = True
test_missing_keys = True
test_model_parallel = False
test_torch_exportable = False
# Used in `check_training_gradient_checkpointing` to NOT check all params having gradient (e.g. for some MOE models)
test_all_params_have_gradient = True
is_encoder_decoder = False
@ -4865,6 +4867,72 @@ class ModelTesterMixin:
# Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
torch.testing.assert_close(all_logits[:, -1:, :], last_token_logits, rtol=1e-5, atol=1e-5)
@slow
@require_torch_greater_or_equal("2.5")
def test_torch_export(self, config=None, inputs_dict=None, tolerance=1e-4):
"""
Test if model can be exported with torch.export.export()
Args:
config (PretrainedConfig):
Config to use for the model, if None, use default config from model_tester
inputs_dict (dict):
Inputs to use for the model, if None, use default inputs from model_tester
tolerance (float):
`atol`/`rtol` for torch.testing.assert_close(), defined in the signature to allow overriding in subclasses
"""
if not self.test_torch_exportable:
self.skipTest(reason="test_torch_exportable=False for this model.")
def recursively_check(eager_outputs, exported_outputs):
is_tested = False
if isinstance(eager_outputs, torch.Tensor):
torch.testing.assert_close(eager_outputs, exported_outputs, atol=tolerance, rtol=tolerance)
return True
elif isinstance(eager_outputs, (tuple, list)):
for eager_output, exported_output in zip(eager_outputs, exported_outputs):
is_tested = recursively_check(eager_output, exported_output) or is_tested
return is_tested
elif isinstance(eager_outputs, dict):
for key in eager_outputs:
is_tested = recursively_check(eager_outputs[key], exported_outputs[key]) or is_tested
return is_tested
return is_tested
default_config, default_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config = config or default_config
inputs_dict = inputs_dict or default_inputs_dict
for model_class in self.all_model_classes:
if model_class.__name__.endswith("ForPreTraining"):
continue
with self.subTest(model_class.__name__):
model = model_class(config).eval().to(torch_device)
# Export model
exported_model = torch.export.export(
model,
args=(),
kwargs=inputs_dict,
strict=True,
)
# Run exported model and eager model
with torch.no_grad():
# set seed in case anything is not deterministic in model (e.g. vit_mae noise)
torch.manual_seed(1234)
eager_outputs = model(**inputs_dict)
torch.manual_seed(1234)
exported_outputs = exported_model.module().forward(**inputs_dict)
# Check if outputs are close:
# is_tested is a boolean flag indicating whether we compared any outputs;
# e.g. if the outputs are an empty list, is_tested will stay False.
# If any outputs differ, the error is raised inside the `recursively_check` function.
is_tested = recursively_check(eager_outputs, exported_outputs)
self.assertTrue(is_tested, msg=f"No outputs were compared for {model_class.__name__}")
@require_torch_gpu
def test_flex_attention_with_grads(self):
for model_class in self.all_model_classes: