Fix ONNX exports for Optimum compatible models (#31311)

* fixed models

* format with bumped ruff version on my local machine

* fix copies

* add tracing checks

* format

* Update src/transformers/utils/generic.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* format

* style fix

* Update modeling_mobilevit.py

* add docstring and change name

* Update __init__.py

* Update __init__.py

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Authored by Merve Noyan on 2024-06-27 12:46:36 +03:00; committed by GitHub
parent dc76e9fa7f
commit c9f191a0b7
10 changed files with 72 additions and 17 deletions
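
The change repeated across these files replaces Python int() / float() / math casts on tensor-derived sizes with tracing-aware helpers: while torch.jit.is_tracing() is true, the value stays a tensor, so the ONNX graph exported through Optimum keeps its height/width dimensions dynamic instead of baking the example input's sizes in as constants. A minimal standalone sketch of the pattern (the tracing_safe_int helper and PatchGrid module below are illustrative names, not the transformers implementation):

# Minimal sketch of the tracing-aware cast pattern (hypothetical names).
import torch


def tracing_safe_int(x):
    # Keep the value as an int64 tensor during tracing; otherwise return a plain Python int.
    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)


class PatchGrid(torch.nn.Module):
    patch_size = 16

    def forward(self, pixel_values):
        # With a plain int(...) cast, tracing would fold 224 / 16 into the graph as a
        # constant, so the exported model would silently break for other image sizes.
        grid_h = tracing_safe_int(pixel_values.shape[2] / self.patch_size)
        grid_w = tracing_safe_int(pixel_values.shape[3] / self.patch_size)
        return grid_h * grid_w


# Tracing records the shape arithmetic instead of constant-folding it.
traced = torch.jit.trace(PatchGrid(), torch.randn(1, 3, 224, 224))

Eager calls behave exactly as before; only the traced path changes, which is why the diffs below guard the new behaviour behind torch.jit.is_tracing().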

src/transformers/models/clap/modeling_clap.py

@@ -37,6 +37,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
@@ -590,8 +591,10 @@ class ClapAudioLayer(nn.Module):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:

src/transformers/models/donut/modeling_donut_swin.py

@@ -35,6 +35,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    torch_int,
 )
 from .configuration_donut_swin import DonutSwinConfig
@@ -562,8 +563,10 @@ class DonutSwinLayer(nn.Module):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:

src/transformers/models/dpt/modeling_dpt.py

@@ -39,7 +39,7 @@ from ...file_utils import (
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import ModelOutput, logging
+from ...utils import ModelOutput, logging, torch_int
 from ...utils.backbone_utils import load_backbone
 from .configuration_dpt import DPTConfig
@@ -226,7 +226,7 @@ class DPTViTEmbeddings(nn.Module):
         posemb_tok = posemb[:, :start_index]
         posemb_grid = posemb[0, start_index:]

-        old_grid_size = int(math.sqrt(len(posemb_grid)))
+        old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)

         posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
         posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")

src/transformers/models/imagegpt/modeling_imagegpt.py

@@ -33,7 +33,13 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_float,
+)
 from .configuration_imagegpt import ImageGPTConfig
@@ -229,7 +235,7 @@ class ImageGPTAttention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

         if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+            attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:

src/transformers/models/layoutlmv3/modeling_layoutlmv3.py

@@ -33,7 +33,13 @@ from ...modeling_outputs import (
 )
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import apply_chunking_to_forward
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
 from .configuration_layoutlmv3 import LayoutLMv3Config
@@ -910,8 +916,8 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
         patch_height = patch_width = None
         if pixel_values is not None:
             patch_height, patch_width = (
-                int(pixel_values.shape[2] / self.config.patch_size),
-                int(pixel_values.shape[3] / self.config.patch_size),
+                torch_int(pixel_values.shape[2] / self.config.patch_size),
+                torch_int(pixel_values.shape[3] / self.config.patch_size),
             )
             visual_embeddings = self.forward_image(pixel_values)
             visual_attention_mask = torch.ones(

src/transformers/models/mobilevit/modeling_mobilevit.py

@@ -39,6 +39,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from .configuration_mobilevit import MobileViTConfig
@@ -437,8 +438,16 @@ class MobileViTLayer(nn.Module):
         batch_size, channels, orig_height, orig_width = features.shape

-        new_height = int(math.ceil(orig_height / patch_height) * patch_height)
-        new_width = int(math.ceil(orig_width / patch_width) * patch_width)
+        new_height = (
+            torch_int(torch.ceil(orig_height / patch_height) * patch_height)
+            if torch.jit.is_tracing()
+            else int(math.ceil(orig_height / patch_height) * patch_height)
+        )
+        new_width = (
+            torch_int(torch.ceil(orig_width / patch_width) * patch_width)
+            if torch.jit.is_tracing()
+            else int(math.ceil(orig_width / patch_width) * patch_width)
+        )

         interpolate = False
         if new_width != orig_width or new_height != orig_height:

src/transformers/models/sam/modeling_sam.py

@@ -15,7 +15,6 @@
 """PyTorch SAM model."""

 import collections
-import math
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
@@ -232,7 +231,7 @@ class SamAttention(nn.Module):
         # SamAttention
         _, _, _, c_per_head = query.shape
         attn = query @ key.permute(0, 1, 3, 2)  # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
-        attn = attn / math.sqrt(c_per_head)
+        attn = attn / (c_per_head**0.5)
         attn = torch.softmax(attn, dim=-1)

         if attention_similarity is not None:

src/transformers/models/swin/modeling_swin.py

@@ -36,6 +36,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     logging,
     replace_return_docstrings,
+    torch_int,
 )
 from ...utils.backbone_utils import BackboneMixin
 from .configuration_swin import SwinConfig
@@ -639,8 +640,10 @@ class SwinLayer(nn.Module):
     def set_shift_and_window_size(self, input_resolution):
         if min(input_resolution) <= self.window_size:
             # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(input_resolution)
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )

     def get_attn_mask(self, height, width, dtype, device):
         if self.shift_size > 0:

src/transformers/utils/__init__.py

@@ -60,6 +60,8 @@ from .generic import (
     tensor_size,
     to_numpy,
     to_py_obj,
+    torch_float,
+    torch_int,
     transpose,
     working_or_temp_dir,
 )

src/transformers/utils/generic.py

@@ -753,6 +753,30 @@ def infer_framework(model_class):
         raise TypeError(f"Could not infer framework from class {model_class}.")


+def torch_int(x):
+    """
+    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
+    """
+    if not is_torch_available():
+        return int(x)
+
+    import torch
+
+    return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
+
+
+def torch_float(x):
+    """
+    Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
+    """
+    if not is_torch_available():
+        return int(x)
+
+    import torch
+
+    return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
+
+
 def filter_out_non_signature_kwargs(extra: Optional[list] = None):
     """
     Decorator to filter out named arguments that are not in the function signature.