Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Add common test for torch.export and fix some vision models (#35124)
* Add is_torch_greater_or_equal test decorator
* Add common test for torch.export
* Fix bit
* Fix focalnet
* Fix imagegpt
* Fix seggpt
* Fix swin2sr
* Enable torch.export test for vision models
* Enable test for video models
* Remove json
* Enable for hiera
* Enable for ijepa
* Fix detr
* Fix conditional_detr
* Fix maskformer
* Enable test maskformer
* Fix test for deformable detr
* Fix custom kernels for export in rt-detr and deformable-detr
* Enable test for all DPT
* Remove custom test for deformable detr
* Simplify test to use only kwargs for export
* Add comment
* Move compile_compatible_method_lru_cache to utils
* Fix beit export
* Fix deformable detr
* Fix copies data2vec<->beit
* Fix typos, update test to work with dict
* Add seed to the test
* Enable test for vit_mae
* Fix beit tests
* [run-slow] beit, bit, conditional_detr, data2vec, deformable_detr, detr, focalnet, imagegpt, maskformer, rt_detr, seggpt, swin2sr
* Add vitpose test
* Add textnet test
* Add dinov2 with registers
* Update tests/test_modeling_common.py
* Switch to torch.testing.assert_close
* Fix maskformer
* Remove save-load from test
* Add dab_detr
* Add depth_pro
* Fix and test RT-DETRv2
* Fix dab_detr
This commit is contained in: parent 1779f5180e, commit f42d46ccb4
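Note: the common test referenced above exercises a kwargs-only torch.export round trip. The snippet below is only a rough, self-contained sketch of that flow; the checkpoint and input shape are illustrative assumptions, not taken from the PR.

    # Sketch of the kwargs-only torch.export flow the common test exercises.
    # The checkpoint and input shape are illustrative assumptions.
    import torch
    from transformers import AutoModelForImageClassification

    model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").eval()
    pixel_values = torch.randn(1, 3, 224, 224)

    # Export with keyword arguments only ("Simplify test to use only kwargs for export").
    exported = torch.export.export(model, args=(), kwargs={"pixel_values": pixel_values}, strict=True)

    with torch.no_grad():
        eager_logits = model(pixel_values=pixel_values).logits
        exported_logits = exported.module()(pixel_values=pixel_values).logits

    # The common test compares eager and exported outputs with torch.testing.assert_close.
    torch.testing.assert_close(eager_logits, exported_logits)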
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,

@@ -297,10 +297,9 @@ class BeitSelfAttention(nn.Module):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

if window_size:
self.has_relative_position_bias = bool(window_size)
if self.has_relative_position_bias:
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
else:
self.relative_position_bias = None

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)

@@ -312,7 +311,7 @@ class BeitSelfAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:

@@ -328,7 +327,7 @@ class BeitSelfAttention(nn.Module):
attention_scores = attention_scores / math.sqrt(self.attention_head_size)

# Add relative position bias if present.
if self.relative_position_bias is not None:
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attention_scores = attention_scores + self.relative_position_bias(

@@ -367,7 +366,7 @@ class BeitSdpaSelfAttention(BeitSelfAttention):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:

@@ -393,7 +392,7 @@ class BeitSdpaSelfAttention(BeitSelfAttention):
query_layer = self.transpose_for_scores(mixed_query_layer)

attn_bias = None
if self.relative_position_bias is not None:
if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attn_bias = self.relative_position_bias(

@@ -477,7 +476,7 @@ class BeitAttention(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:

@@ -546,7 +545,7 @@ class BeitLayer(nn.Module):
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
relative_position_bias: Optional[torch.Tensor] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:

@@ -595,8 +594,7 @@ class BeitRelativePositionBias(nn.Module):
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls

self.relative_position_indices = {}

@compile_compatible_method_lru_cache(maxsize=10)
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
"""
This method creates the relative position index, modified to support arbitrary window sizes,

@@ -648,11 +646,9 @@ class BeitRelativePositionBias(nn.Module):
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
)

key = window_size
if key not in self.relative_position_indices.keys():
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
relative_position_index = self.generate_relative_position_index(window_size)
relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]

relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
relative_position_bias = relative_position_bias.view(
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1

@@ -675,10 +671,9 @@ class BeitEncoder(nn.Module):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.config = config
if config.use_shared_relative_position_bias:
self.has_relative_position_bias = config.use_shared_relative_position_bias
if self.has_relative_position_bias:
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
else:
self.relative_position_bias = None

# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]

@@ -701,7 +696,7 @@ class BeitEncoder(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
resolution: Optional[Tuple[int, int]] = None,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None

@@ -711,6 +706,15 @@ class BeitEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if self.has_relative_position_bias:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
else:
relative_position_bias = None

layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:

@@ -719,17 +723,11 @@ class BeitEncoder(nn.Module):
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
else:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = (
self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states,
layer_head_mask,
@@ -192,7 +192,7 @@ class DynamicPad2d(nn.Module):

self.compute_padding = compute_padding

def __call__(self, input):
def forward(self, input):
# Get width and height
input_height, input_width = input.size()[-2:]
@@ -1735,7 +1735,11 @@ class ConditionalDetrForObjectDetection(ConditionalDetrPreTrainedModel):
# class logits + predicted bounding boxes
logits = self.class_labels_classifier(sequence_output)

reference = outputs.reference_points if return_dict else outputs[-1]
# Index [-2] is valid only if `output_attentions` and `output_hidden_states`
# are not specified, otherwise it will be another index which is hard to determine.
# Leave it as is, because it's not a common case to use
# return_dict=False + output_attentions=True / output_hidden_states=True
reference = outputs.reference_points if return_dict else outputs[-2]
reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)

hs = sequence_output

@@ -2105,7 +2109,7 @@ class ConditionalDetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
@@ -1537,7 +1537,7 @@ class DabDetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
@ -32,7 +32,7 @@ from ...modeling_outputs import (
|
||||
SemanticSegmenterOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...pytorch_utils import compile_compatible_method_lru_cache, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -298,10 +298,9 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
if window_size:
|
||||
self.has_relative_position_bias = bool(window_size)
|
||||
if self.has_relative_position_bias:
|
||||
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
|
||||
else:
|
||||
self.relative_position_bias = None
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
@ -313,7 +312,7 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
relative_position_bias: Optional[torch.Tensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
@ -329,7 +328,7 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Add relative position bias if present.
|
||||
if self.relative_position_bias is not None:
|
||||
if self.has_relative_position_bias:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
attention_scores = attention_scores + self.relative_position_bias(
|
||||
@ -369,7 +368,7 @@ class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
relative_position_bias: Optional[torch.Tensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
@ -395,7 +394,7 @@ class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention):
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
attn_bias = None
|
||||
if self.relative_position_bias is not None:
|
||||
if self.has_relative_position_bias:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
attn_bias = self.relative_position_bias(
|
||||
@ -557,7 +556,7 @@ class Data2VecVisionLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
relative_position_bias: Optional[torch.Tensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
@ -607,8 +606,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
self.relative_position_indices = {}
|
||||
|
||||
@compile_compatible_method_lru_cache(maxsize=10)
|
||||
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
This method creates the relative position index, modified to support arbitrary window sizes,
|
||||
@ -660,11 +658,9 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
||||
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
|
||||
)
|
||||
|
||||
key = window_size
|
||||
if key not in self.relative_position_indices.keys():
|
||||
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
|
||||
relative_position_index = self.generate_relative_position_index(window_size)
|
||||
relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
|
||||
|
||||
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
|
||||
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
|
||||
@ -688,10 +684,9 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.use_shared_relative_position_bias:
|
||||
self.has_relative_position_bias = config.use_shared_relative_position_bias
|
||||
if self.has_relative_position_bias:
|
||||
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
|
||||
else:
|
||||
self.relative_position_bias = None
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
||||
@ -714,7 +709,7 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
resolution: Optional[Tuple[int, int]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -724,6 +719,15 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.has_relative_position_bias:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
relative_position_bias = self.relative_position_bias(
|
||||
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
||||
)
|
||||
else:
|
||||
relative_position_bias = None
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
@ -732,17 +736,11 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
relative_position_bias,
|
||||
interpolate_pos_encoding,
|
||||
resolution,
|
||||
)
|
||||
else:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
relative_position_bias = (
|
||||
self.relative_position_bias(
|
||||
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
||||
)
|
||||
if self.relative_position_bias is not None
|
||||
else None
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
|
@@ -40,6 +40,7 @@ from ...utils import (
is_ninja_available,
is_timm_available,
is_torch_cuda_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
requires_backends,

@@ -705,7 +706,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
# PyTorch implementation
output = multi_scale_deformable_attention(
value, spatial_shapes_list, sampling_locations, attention_weights

@@ -1606,7 +1607,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
Args:
enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
spatial_shapes (Tensor[num_feature_levels, 2]): Spatial shapes of the feature maps.
spatial_shapes (List[Tuple[int, int]]): Spatial shapes of the feature maps.

Returns:
`tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.

@@ -1786,7 +1787,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
enc_outputs_coord_logits = None
if self.config.two_stage:
object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
encoder_outputs[0], ~mask_flatten, spatial_shapes
encoder_outputs[0], ~mask_flatten, spatial_shapes_list
)

# hack implementation for two-stage Deformable DETR
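The is_torchdynamo_compiling() guard added in this hunk forces the pure-PyTorch multi-scale deformable attention path whenever the graph is being traced for torch.compile / torch.export. Below is a self-contained toy version of the same dispatch pattern, not the library code; `fast_kernel` merely stands in for the optional compiled extension.

    # Toy version of the guard above: prefer a custom kernel in eager mode, but fall back to
    # plain PyTorch whenever the extension is missing or dynamo/export is tracing.
    import torch
    from torch import nn
    from transformers.utils import is_torchdynamo_compiling

    fast_kernel = None  # assume the compiled extension failed to load

    class GuardedAttention(nn.Module):
        def __init__(self, disable_custom_kernels: bool = False):
            super().__init__()
            self.disable_custom_kernels = disable_custom_kernels

        def forward(self, value: torch.Tensor) -> torch.Tensor:
            if self.disable_custom_kernels or fast_kernel is None or is_torchdynamo_compiling():
                return torch.softmax(value, dim=-1)  # pure-PyTorch fallback, traceable by export
            return fast_kernel(value)                # eager-only custom kernel

    module = GuardedAttention()
    exported = torch.export.export(module, args=(), kwargs={"value": torch.randn(2, 4)})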
@@ -1801,7 +1801,7 @@ class DetrMHAttentionMap(nn.Module):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
@@ -358,23 +358,23 @@ class FocalNetModulation(nn.Module):

# pre linear projection
x = self.projection_in(hidden_state).permute(0, 3, 1, 2).contiguous()
q, ctx, self.gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)
q, ctx, gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)

# context aggreation
ctx_all = 0
for level in range(self.focal_level):
ctx = self.focal_layers[level](ctx)
ctx_all = ctx_all + ctx * self.gates[:, level : level + 1]
ctx_all = ctx_all + ctx * gates[:, level : level + 1]
ctx_global = self.activation(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level :]
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level :]

# normalize context
if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1)

# focal modulation
self.modulator = self.projection_context(ctx_all)
x_out = q * self.modulator
modulator = self.projection_context(ctx_all)
x_out = q * modulator
x_out = x_out.permute(0, 2, 3, 1).contiguous()
if self.use_post_layernorm_in_modulation:
x_out = self.layernorm(x_out)
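The FocalNet edit above is a general export/compile hygiene fix: writing intermediate activations onto the module (self.gates, self.modulator) mutates state inside forward, which graph capture treats as a side effect; plain locals avoid that. A contrived sketch under that assumption, not taken from the library:

    # Contrived sketch: keep intermediate activations in locals rather than on the module.
    import torch
    from torch import nn

    class WithSideEffect(nn.Module):
        def forward(self, x):
            self.gate = torch.sigmoid(x)  # mutates module state on every call
            return x * self.gate

    class WithoutSideEffect(nn.Module):
        def forward(self, x):
            gate = torch.sigmoid(x)       # plain local, no state mutation
            return x * gate

    # The side-effect-free variant exports cleanly.
    exported = torch.export.export(WithoutSideEffect(), args=(torch.randn(2, 3),))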
@@ -164,13 +164,11 @@ class ImageGPTLayerNorm(nn.Module):
self.eps = eps
self.weight = nn.Parameter(torch.Tensor(hidden_size))

def forward(self, tensor: torch.Tensor) -> tuple:
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
# input is not mean centered
return (
tensor
/ torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
* self.weight.data[..., :]
)
tensor = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
tensor = tensor * self.weight
return tensor


class ImageGPTAttention(nn.Module):
@@ -1424,7 +1424,11 @@ class MaskFormerTransformerModule(nn.Module):
# repeat the queries "q c -> b q c"
batch_size = image_features.shape[0]
queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=self.training)

# torch.export.export does no support requires_grad
if self.training:
inputs_embeds.requires_grad_(True)

batch_size, num_channels, height, width = image_features.shape
# rearrange both image_features and object_queries "b c h w -> b (h w) c"
@ -18,7 +18,7 @@ import math
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache, partial, wraps
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -32,12 +32,14 @@ from ...activations import ACT2CLS, ACT2FN
|
||||
from ...image_transforms import center_to_corners_format, corners_to_center_format
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import compile_compatible_method_lru_cache
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_ninja_available,
|
||||
is_torch_cuda_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -870,7 +872,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
||||
|
||||
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
|
||||
if self.disable_custom_kernels or MultiScaleDeformableAttention is None or is_torchdynamo_compiling():
|
||||
# PyTorch implementation
|
||||
output = multi_scale_deformable_attention(
|
||||
value, spatial_shapes_list, sampling_locations, attention_weights
|
||||
@ -1590,27 +1592,6 @@ class RTDetrDecoder(RTDetrPreTrainedModel):
|
||||
)
|
||||
|
||||
|
||||
def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not torch.compiler.is_compiling():
|
||||
# Cache the function only if the model is not being compiled
|
||||
# check if the function is already cached, otherwise create it
|
||||
if not hasattr(self, f"_cached_{func.__name__}"):
|
||||
self.__setattr__(
|
||||
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
|
||||
)
|
||||
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
|
||||
else:
|
||||
# Otherwise, just call the original function
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||||
class RTDetrMLPPredictionHead(nn.Module):
|
||||
"""
|
||||
@ -1728,7 +1709,7 @@ class RTDetrModel(RTDetrPreTrainedModel):
|
||||
for param in self.backbone.parameters():
|
||||
param.requires_grad_(True)
|
||||
|
||||
@compile_compatible_lru_cache(maxsize=32)
|
||||
@compile_compatible_method_lru_cache(maxsize=32)
|
||||
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
|
||||
if spatial_shapes is None:
|
||||
spatial_shapes = [
|
||||
|
@ -22,7 +22,7 @@ import math
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache, partial, wraps
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -34,12 +34,14 @@ from ...activations import ACT2CLS, ACT2FN
|
||||
from ...image_transforms import center_to_corners_format, corners_to_center_format
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import compile_compatible_method_lru_cache
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_ninja_available,
|
||||
is_torch_cuda_available,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@ -97,7 +99,7 @@ def multi_scale_deformable_attention_v2(
|
||||
value_list = (
|
||||
value.permute(0, 2, 3, 1)
|
||||
.flatten(0, 1)
|
||||
.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1)
|
||||
.split([height * width for height, width in value_spatial_shapes], dim=-1)
|
||||
)
|
||||
# sampling_offsets [8, 480, 8, 12, 2]
|
||||
if method == "default":
|
||||
@ -226,9 +228,9 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Process inputs up to sampling locations calculation using parent class logic
|
||||
if position_embeddings is not None:
|
||||
@ -236,7 +238,7 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
||||
|
||||
batch_size, num_queries, _ = hidden_states.shape
|
||||
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
||||
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
||||
raise ValueError(
|
||||
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
||||
)
|
||||
@ -272,7 +274,7 @@ class RTDetrV2MultiscaleDeformableAttention(nn.Module):
|
||||
|
||||
# V2-specific attention implementation choice
|
||||
output = multi_scale_deformable_attention_v2(
|
||||
value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list, self.method
|
||||
value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
|
||||
)
|
||||
|
||||
output = self.output_proj(output)
|
||||
@ -1329,27 +1331,6 @@ RTDetrV2_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
def compile_compatible_lru_cache(*lru_args, **lru_kwargs):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not torch.compiler.is_compiling():
|
||||
# Cache the function only if the model is not being compiled
|
||||
# check if the function is already cached, otherwise create it
|
||||
if not hasattr(self, f"_cached_{func.__name__}"):
|
||||
self.__setattr__(
|
||||
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
|
||||
)
|
||||
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
|
||||
else:
|
||||
# Otherwise, just call the original function
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_clones(partial_module, N):
|
||||
return nn.ModuleList([partial_module() for i in range(N)])
|
||||
|
||||
@ -1669,7 +1650,7 @@ class RTDetrV2Model(RTDetrV2PreTrainedModel):
|
||||
for param in self.backbone.parameters():
|
||||
param.requires_grad_(True)
|
||||
|
||||
@compile_compatible_lru_cache(maxsize=32)
|
||||
@compile_compatible_method_lru_cache(maxsize=32)
|
||||
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
|
||||
if spatial_shapes is None:
|
||||
spatial_shapes = [
|
||||
|
@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ...utils.backbone_utils import (
|
||||
verify_backbone_config_arguments,
|
||||
)
|
||||
@ -404,7 +404,7 @@ def multi_scale_deformable_attention_v2(
|
||||
value_list = (
|
||||
value.permute(0, 2, 3, 1)
|
||||
.flatten(0, 1)
|
||||
.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1)
|
||||
.split([height * width for height, width in value_spatial_shapes], dim=-1)
|
||||
)
|
||||
# sampling_offsets [8, 480, 8, 12, 2]
|
||||
if method == "default":
|
||||
@ -497,9 +497,9 @@ class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention)
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Process inputs up to sampling locations calculation using parent class logic
|
||||
if position_embeddings is not None:
|
||||
@ -507,7 +507,7 @@ class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention)
|
||||
|
||||
batch_size, num_queries, _ = hidden_states.shape
|
||||
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
||||
if not is_torchdynamo_compiling() and (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
||||
raise ValueError(
|
||||
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
||||
)
|
||||
@ -543,7 +543,7 @@ class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention)
|
||||
|
||||
# V2-specific attention implementation choice
|
||||
output = multi_scale_deformable_attention_v2(
|
||||
value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list, self.method
|
||||
value, spatial_shapes_list, sampling_locations, attention_weights, self.n_points_list, self.method
|
||||
)
|
||||
|
||||
output = self.output_proj(output)
|
||||
|
@@ -817,8 +817,11 @@ class SegGptModel(SegGptPreTrainedModel):
# and reconstructed together (In-Context Painting).
if bool_masked_pos is None:
num_patches = self.embeddings.patch_embeddings.num_patches
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
bool_masked_pos[num_patches // 2 :] = 1
bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
bool_masked_pos_ones = torch.ones(
num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
)
bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
bool_masked_pos = bool_masked_pos.unsqueeze(0)

embedding_output = self.embeddings(

@@ -975,8 +978,11 @@ class SegGptForImageSegmentation(SegGptPreTrainedModel):

if bool_masked_pos is None:
num_patches = self.model.embeddings.patch_embeddings.num_patches
bool_masked_pos = torch.zeros(num_patches, dtype=torch.bool).to(pixel_values.device)
bool_masked_pos[num_patches // 2 :] = 1
bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
bool_masked_pos_ones = torch.ones(
num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
)
bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
bool_masked_pos = bool_masked_pos.unsqueeze(0)

outputs = self.model(
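The SegGpt change replaces in-place slice assignment on the default bool_masked_pos with a torch.cat of two constant halves, which is friendlier to graph capture. A standalone illustration that the two constructions produce the same mask (num_patches is an arbitrary example value, not from the model):

    # Standalone comparison of the two ways of building the default bool_masked_pos above.
    import torch

    num_patches = 10  # arbitrary example value

    # Old style: allocate, then mutate half of it in place.
    mask_inplace = torch.zeros(num_patches, dtype=torch.bool)
    mask_inplace[num_patches // 2 :] = True

    # New style: concatenate two constant halves, no in-place mutation.
    mask_cat = torch.cat(
        [
            torch.zeros(num_patches // 2, dtype=torch.bool),
            torch.ones(num_patches - num_patches // 2, dtype=torch.bool),
        ]
    )

    assert torch.equal(mask_inplace, mask_cat)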
@@ -813,10 +813,11 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
self.config = config

if config.num_channels == 3 and config.num_channels_out == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
mean = torch.zeros(1, 1, 1, 1)
self.register_buffer("mean", mean, persistent=False)

self.img_range = config.img_range

self.first_convolution = nn.Conv2d(config.num_channels, config.embed_dim, 3, 1, 1)

@@ -851,8 +852,8 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
pixel_values = nn.functional.pad(pixel_values, (0, modulo_pad_width, 0, modulo_pad_height), "reflect")

# 2. normalize
self.mean = self.mean.type_as(pixel_values)
pixel_values = (pixel_values - self.mean) * self.img_range
mean = self.mean.type_as(pixel_values)
pixel_values = (pixel_values - mean) * self.img_range

return pixel_values
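Registering the Swin2SR normalization mean as a non-persistent buffer, instead of reassigning self.mean inside the forward pass, keeps the constant on the right device and out of the checkpoint, and avoids attribute mutation during tracing. A minimal sketch of the pattern, independent of Swin2SR:

    # Minimal sketch of the non-persistent buffer pattern used above.
    import torch
    from torch import nn

    class Normalizer(nn.Module):
        def __init__(self):
            super().__init__()
            mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
            # persistent=False keeps the constant out of the state dict / checkpoint
            self.register_buffer("mean", mean, persistent=False)

        def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
            # read-only use of the buffer; no attribute reassignment inside forward
            return pixel_values - self.mean.type_as(pixel_values)

    module = Normalizer()
    print(module.state_dict().keys())  # the "mean" buffer does not appear here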
@@ -14,6 +14,7 @@
from __future__ import annotations

import inspect
from functools import lru_cache, wraps
from typing import Callable, List, Optional, Set, Tuple, Union

import torch

@@ -21,7 +22,7 @@ from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn

from .utils import is_torch_greater_or_equal, is_torch_xla_available, logging
from .utils import is_torch_greater_or_equal, is_torch_xla_available, is_torchdynamo_compiling, logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

@@ -364,3 +365,29 @@ def translate_to_torch_parallel_style(style: str):
return RowwiseParallel(input_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")


def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
"""
LRU cache decorator from standard functools library, but with a workaround to disable
caching when torchdynamo is compiling. Expected to work with class methods.
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not is_torchdynamo_compiling():
# Cache the function only if the model is not being compiled
# check if the function is already cached, otherwise create it
if not hasattr(self, f"_cached_{func.__name__}"):
self.__setattr__(
f"_cached_{func.__name__}", lru_cache(*lru_args, **lru_kwargs)(func.__get__(self))
)
return self.__getattribute__(f"_cached_{func.__name__}")(*args, **kwargs)
else:
# Otherwise, just call the original function
return func(self, *args, **kwargs)

return wrapper

return decorator
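A usage sketch for the helper above (the toy module and method are illustrative, not from the PR): decorate an instance method whose result depends only on hashable arguments; in eager mode the result is memoized per instance, while under torchdynamo the wrapper calls straight through.

    # Illustrative usage of compile_compatible_method_lru_cache; the toy module is an assumption.
    import torch
    from torch import nn
    from transformers.pytorch_utils import compile_compatible_method_lru_cache

    class WindowIndex(nn.Module):
        @compile_compatible_method_lru_cache(maxsize=10)
        def relative_index(self, window_size: tuple) -> torch.Tensor:
            # Cheap here, but stands in for an expensive index computation worth caching.
            height, width = window_size
            return torch.arange(height * width).view(height, width)

    module = WindowIndex()
    first = module.relative_index((4, 4))
    second = module.relative_index((4, 4))  # served from the per-instance LRU cache in eager mode
    assert first is second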
@@ -135,6 +135,7 @@ from .utils import (
is_torch_bf16_gpu_available,
is_torch_deterministic,
is_torch_fp16_available_on_device,
is_torch_greater_or_equal,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_sdpa_available,

@@ -556,6 +557,21 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


def require_torch_greater_or_equal(version: str):
"""
Decorator marking a test that requires PyTorch version >= `version`.

These tests are skipped when PyTorch version is less than `version`.
"""

def decorator(test_case):
return unittest.skipUnless(is_torch_greater_or_equal(version), f"test requires PyTorch version >= {version}")(
test_case
)

return decorator


def require_flash_attn(test_case):
"""
Decorator marking a test that requires Flash Attention.
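The new decorator composes like the existing require_* helpers in testing_utils. A hedged usage sketch follows; the test class, body, and version gate are placeholders, not from the PR.

    # Placeholder test showing how require_torch_greater_or_equal is applied.
    import unittest

    import torch
    from transformers.testing_utils import require_torch, require_torch_greater_or_equal

    @require_torch
    class ExportSmokeTest(unittest.TestCase):
        @require_torch_greater_or_equal("2.3")  # version gate chosen only for illustration
        def test_export_runs(self):
            model = torch.nn.Linear(4, 4).eval()
            exported = torch.export.export(model, args=(torch.randn(1, 4),))
            self.assertIsNotNone(exported)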
@ -271,6 +271,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BeitModelTester(self)
|
||||
@ -292,6 +293,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BEiT can't compile dynamic")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
def test_model_get_set_embeddings(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -764,13 +769,6 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
inputs = processor(images=image, return_tensors="pt", size={"height": 480, "width": 480})
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# with interpolate_pos_encoding being False an exception should be raised with higher resolution
|
||||
# images than what the model supports.
|
||||
self.assertFalse(processor.do_center_crop)
|
||||
with torch.no_grad():
|
||||
with self.assertRaises(ValueError, msg="doesn't match model"):
|
||||
model(pixel_values, interpolate_pos_encoding=False)
|
||||
|
||||
# with interpolate_pos_encoding being True the model should process the higher resolution image
|
||||
# successfully and produce the expected output.
|
||||
with torch.no_grad():
|
||||
|
@ -170,6 +170,7 @@ class BitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BitModelTester(self)
|
||||
|
@ -194,6 +194,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
zero_init_hidden_state = True
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -180,6 +180,7 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ConvNextModelTester(self)
|
||||
|
@ -188,6 +188,7 @@ class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ConvNextV2ModelTester(self)
|
||||
|
@ -159,6 +159,7 @@ class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = CvtModelTester(self)
|
||||
|
@ -197,6 +197,7 @@ class DabDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
zero_init_hidden_state = True
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -200,6 +200,7 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -222,6 +222,7 @@ class DeiTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DeiTModelTester(self)
|
||||
|
@ -146,6 +146,7 @@ class DepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DepthAnythingModelTester(self)
|
||||
|
@ -212,6 +212,7 @@ class DepthProModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DepthProModelTester(self)
|
||||
|
@ -194,6 +194,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
zero_init_hidden_state = True
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -216,6 +216,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DinatModelTester(self)
|
||||
|
@ -212,6 +212,8 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
test_torch_exportable = True
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
Dinov2Model,
|
||||
|
@ -237,6 +237,7 @@ class Dinov2WithRegistersModelTest(ModelTesterMixin, PipelineTesterMixin, unitte
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Dinov2WithRegistersModelTester(self)
|
||||
|
@ -172,6 +172,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DPTModelTester(self)
|
||||
|
@ -140,6 +140,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DPTModelTester(self)
|
||||
|
@ -186,6 +186,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DPTModelTester(self)
|
||||
|
@ -139,6 +139,7 @@ class EfficientNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = EfficientNetModelTester(self)
|
||||
|
@ -247,6 +247,7 @@ class FocalNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FocalNetModelTester(self)
|
||||
|
@ -152,6 +152,7 @@ class GLPNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = GLPNModelTester(self)
|
||||
|
@ -250,6 +250,7 @@ class HieraModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = HieraModelTester(self)
|
||||
|
@ -207,6 +207,7 @@ class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = IJepaModelTester(self)
|
||||
|
@ -237,6 +237,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
else {}
|
||||
)
|
||||
test_missing_keys = False
|
||||
test_torch_exportable = True
|
||||
|
||||
# as ImageGPTForImageClassification isn't included in any auto mapping, we add labels here
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -205,6 +205,7 @@ class Mask2FormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Mask2FormerModelTester(self)
|
||||
|
@ -209,6 +209,7 @@ class MaskFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
zero_init_hidden_state = True
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MaskFormerModelTester(self)
|
||||
|
@ -181,6 +181,7 @@ class MaskFormerSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MaskFormerSwinModelTester(self)
|
||||
|
@ -154,6 +154,7 @@ class MobileNetV1ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MobileNetV1ModelTester(self)
|
||||
|
@ -205,6 +205,7 @@ class MobileNetV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MobileNetV2ModelTester(self)
|
||||
|
@ -198,6 +198,7 @@ class MobileViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MobileViTModelTester(self)
|
||||
|
@ -200,6 +200,7 @@ class MobileViTV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MobileViTV2ModelTester(self)
|
||||
|
@ -132,6 +132,7 @@ class PoolFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
test_resize_embeddings = False
|
||||
test_torchscript = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = PoolFormerModelTester(self)
|
||||
|
@ -166,6 +166,7 @@ class PvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_torchscript = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = PvtModelTester(self)
|
||||
|
@ -202,6 +202,7 @@ class PvtV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_torchscript = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = PvtV2ModelTester(self)
|
||||
|
@ -133,6 +133,7 @@ class RegNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RegNetModelTester(self)
|
||||
|
@ -178,6 +178,7 @@ class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ResNetModelTester(self)
|
||||
|
@ -261,6 +261,7 @@ class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -259,6 +259,7 @@ class RTDetrV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -180,6 +180,7 @@ class SegformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SegformerModelTester(self)
|
||||
|
@ -172,6 +172,8 @@ class SegGptModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torchscript = False
|
||||
test_torch_exportable = True
|
||||
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": SegGptModel, "mask-generation": SegGptModel} if is_torch_available() else {}
|
||||
)
|
||||
|
@ -147,6 +147,7 @@ class SwiftFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
has_attentions = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SwiftFormerModelTester(self)
|
||||
|
@ -240,6 +240,7 @@ class SwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SwinModelTester(self)
|
||||
|
@ -172,6 +172,7 @@ class Swin2SRModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torchscript = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Swin2SRModelTester(self)
|
||||
|
@ -226,6 +226,7 @@ class Swinv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Swinv2ModelTester(self)
|
||||
|
@ -209,6 +209,7 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
zero_init_hidden_state = True
|
||||
test_torch_exportable = True
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -217,6 +217,7 @@ class TextNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
has_attentions = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -167,6 +167,7 @@ class TimesformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_torch_exportable = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TimesformerModelTester(self)
|
||||
|
@ -157,6 +157,7 @@ class UperNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
    test_head_masking = False
    test_torchscript = False
    has_attentions = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = UperNetModelTester(self)

@ -186,6 +186,7 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
    test_torchscript = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = VideoMAEModelTester(self)

@ -207,6 +207,7 @@ class ViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = ViTModelTester(self)

@ -174,6 +174,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_torchscript = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = ViTMAEModelTester(self)

@ -162,6 +162,7 @@ class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_torchscript = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = ViTMSNModelTester(self)

@ -169,6 +169,7 @@ class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = VitDetModelTester(self)

@ -143,6 +143,7 @@ class VitMatteModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = VitMatteModelTester(self)

@ -154,6 +154,7 @@ class VitPoseModelTest(ModelTesterMixin, unittest.TestCase):
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = VitPoseModelTester(self)
@ -18,7 +18,7 @@ import inspect
import unittest

from transformers import VitPoseBackboneConfig
from transformers.testing_utils import require_torch
from transformers.testing_utils import require_torch, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_backbone_common import BackboneTesterMixin

@ -27,6 +27,8 @@ from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


if is_torch_available():
    import torch

    from transformers import VitPoseBackbone


@ -129,6 +131,7 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = VitPoseBackboneModelTester(self)
@ -187,6 +190,17 @@ class VitPoseBackboneModelTest(ModelTesterMixin, unittest.TestCase):
        expected_arg_names = ["pixel_values"]
        self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_torch_export(self):
        # Dense architecture
        super().test_torch_export()

        # MOE architecture
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_experts = 2
        config.part_features = config.hidden_size // config.num_experts
        inputs_dict["dataset_index"] = torch.tensor([0] * self.model_tester.batch_size, device=torch_device)
        super().test_torch_export(config=config, inputs_dict=inputs_dict)


@require_torch
class VitPoseBackboneTest(unittest.TestCase, BackboneTesterMixin):
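The override above is the intended extension point: instead of duplicating the export logic, a model test re-runs the shared test_torch_export with a custom config and extra inputs. A minimal sketch of the same pattern for a hypothetical model that needs an additional boolean mask input (the method body below is illustrative only; extra_mask and its shape are assumptions, not part of this commit):

def test_torch_export(self):
    # Default inputs first, then re-run with the extra input this (hypothetical) model expects.
    super().test_torch_export()

    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    inputs_dict["extra_mask"] = torch.ones(self.model_tester.batch_size, dtype=torch.bool, device=torch_device)
    super().test_torch_export(config=config, inputs_dict=inputs_dict)

Because the shared test feeds the same kwargs to both the eager model and torch.export.export, any model-specific input can be injected this way without touching the common code.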
@ -175,6 +175,7 @@ class VivitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_torchscript = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = VivitModelTester(self)

@ -178,6 +178,7 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False
    test_torch_exportable = True

    # special case for head model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

@ -147,6 +147,7 @@ class ZoeDepthModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = ZoeDepthModelTester(self)

@ -86,6 +86,7 @@ from transformers.testing_utils import (
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_greater_or_equal,
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    require_torch_sdpa,
@ -221,6 +222,7 @@ class ModelTesterMixin:
    test_mismatched_shapes = True
    test_missing_keys = True
    test_model_parallel = False
    test_torch_exportable = False
    # Used in `check_training_gradient_checkpointing` to NOT check all params having gradient (e.g. for some MOE models)
    test_all_params_have_gradient = True
    is_encoder_decoder = False
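Since the mixin default is test_torch_exportable = False, export coverage stays strictly opt-in; a model-specific test class enables it exactly like the vision tests above. A schematic example (the class and model names are placeholders, not real classes in the repository):

class MyVisionModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (MyVisionModel,) if is_torch_available() else ()
    test_torch_exportable = True  # opt into the shared torch.export test defined below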
@ -4865,6 +4867,72 @@ class ModelTesterMixin:
        # Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
        torch.testing.assert_close(all_logits[:, -1:, :], last_token_logits, rtol=1e-5, atol=1e-5)

    @slow
    @require_torch_greater_or_equal("2.5")
    def test_torch_export(self, config=None, inputs_dict=None, tolerance=1e-4):
        """
        Test whether the model can be exported with torch.export.export().

        Args:
            config (PretrainedConfig):
                Config to use for the model; if None, the default config from `model_tester` is used.
            inputs_dict (dict):
                Inputs to use for the model; if None, the default inputs from `model_tester` are used.
            tolerance (float):
                `atol` and `rtol` passed to torch.testing.assert_close(); defined in the signature so subclasses can override the test.
        """
        if not self.test_torch_exportable:
            self.skipTest(reason="test_torch_exportable=False for this model.")

        def recursively_check(eager_outputs, exported_outputs):
            is_tested = False
            if isinstance(eager_outputs, torch.Tensor):
                torch.testing.assert_close(eager_outputs, exported_outputs, atol=tolerance, rtol=tolerance)
                return True
            elif isinstance(eager_outputs, (tuple, list)):
                for eager_output, exported_output in zip(eager_outputs, exported_outputs):
                    is_tested = is_tested or recursively_check(eager_output, exported_output)
                return is_tested
            elif isinstance(eager_outputs, dict):
                for key in eager_outputs:
                    is_tested = is_tested or recursively_check(eager_outputs[key], exported_outputs[key])
                return is_tested
            return is_tested

        default_config, default_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config = config or default_config
        inputs_dict = inputs_dict or default_inputs_dict

        for model_class in self.all_model_classes:
            if model_class.__name__.endswith("ForPreTraining"):
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config).eval().to(torch_device)

                # Export model
                exported_model = torch.export.export(
                    model,
                    args=(),
                    kwargs=inputs_dict,
                    strict=True,
                )

                # Run exported model and eager model
                with torch.no_grad():
                    # set seed in case anything is not deterministic in model (e.g. vit_mae noise)
                    torch.manual_seed(1234)
                    eager_outputs = model(**inputs_dict)
                    torch.manual_seed(1234)
                    exported_outputs = exported_model.module().forward(**inputs_dict)

                # Check if outputs are close:
                # `is_tested` is a boolean flag indicating whether any outputs were actually compared,
                # e.g. if the outputs are an empty list, `is_tested` stays False.
                # If any compared outputs differ, the error is raised inside `recursively_check`.
                is_tested = recursively_check(eager_outputs, exported_outputs)
                self.assertTrue(is_tested, msg=f"No outputs were compared for {model_class.__name__}")

    @require_torch_gpu
    def test_flex_attention_with_grads(self):
        for model_class in self.all_model_classes:
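Outside the test suite, the round trip the new common test performs can be reproduced in a few lines. A rough standalone sketch, not part of this commit, using a small randomly initialized ViT as a stand-in for any of the models enabled above (the config sizes are arbitrary; it assumes torch >= 2.5 and that the exported program preserves the ModelOutput structure):

import torch
from transformers import ViTConfig, ViTModel

config = ViTConfig(image_size=32, patch_size=8, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = ViTModel(config).eval()
inputs = {"pixel_values": torch.randn(1, 3, 32, 32)}

# Same call convention as the common test: no positional args, model inputs passed as kwargs.
exported = torch.export.export(model, args=(), kwargs=inputs, strict=True)

with torch.no_grad():
    eager_out = model(**inputs)
    exported_out = exported.module()(**inputs)

torch.testing.assert_close(eager_out["last_hidden_state"], exported_out["last_hidden_state"], atol=1e-4, rtol=1e-4)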