mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix ONNX exports for Optimum compatible models (#31311)
* fixed models * format with bumped ruff version on my local * fix copies * add tracing checks * format * Update src/transformers/utils/generic.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * format * style fix * Update modeling_mobilevit.py * add docstring and change name * Update __init__.py * Update __init__.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
dc76e9fa7f
commit
c9f191a0b7
@ -37,6 +37,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
|
||||
|
||||
@ -590,8 +591,10 @@ class ClapAudioLayer(nn.Module):
|
||||
def set_shift_and_window_size(self, input_resolution):
|
||||
if min(input_resolution) <= self.window_size:
|
||||
# if window size is larger than input resolution, we don't partition windows
|
||||
self.shift_size = 0
|
||||
self.window_size = min(input_resolution)
|
||||
self.shift_size = torch_int(0)
|
||||
self.window_size = (
|
||||
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
|
||||
)
|
||||
|
||||
def get_attn_mask(self, height, width, dtype, device):
|
||||
if self.shift_size > 0:
|
||||
|
@ -35,6 +35,7 @@ from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_donut_swin import DonutSwinConfig
|
||||
|
||||
@ -562,8 +563,10 @@ class DonutSwinLayer(nn.Module):
|
||||
def set_shift_and_window_size(self, input_resolution):
|
||||
if min(input_resolution) <= self.window_size:
|
||||
# if window size is larger than input resolution, we don't partition windows
|
||||
self.shift_size = 0
|
||||
self.window_size = min(input_resolution)
|
||||
self.shift_size = torch_int(0)
|
||||
self.window_size = (
|
||||
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
|
||||
)
|
||||
|
||||
def get_attn_mask(self, height, width, dtype, device):
|
||||
if self.shift_size > 0:
|
||||
|
@ -39,7 +39,7 @@ from ...file_utils import (
|
||||
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, logging
|
||||
from ...utils import ModelOutput, logging, torch_int
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_dpt import DPTConfig
|
||||
|
||||
@ -226,7 +226,7 @@ class DPTViTEmbeddings(nn.Module):
|
||||
posemb_tok = posemb[:, :start_index]
|
||||
posemb_grid = posemb[0, start_index:]
|
||||
|
||||
old_grid_size = int(math.sqrt(len(posemb_grid)))
|
||||
old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
|
||||
|
||||
posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
|
||||
|
@ -33,7 +33,13 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_float,
|
||||
)
|
||||
from .configuration_imagegpt import ImageGPTConfig
|
||||
|
||||
|
||||
@ -229,7 +235,7 @@ class ImageGPTAttention(nn.Module):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
if self.scale_attn_weights:
|
||||
attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
|
||||
attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
if self.scale_attn_by_inverse_layer_idx:
|
||||
|
@ -33,7 +33,13 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_layoutlmv3 import LayoutLMv3Config
|
||||
|
||||
|
||||
@ -910,8 +916,8 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
|
||||
patch_height = patch_width = None
|
||||
if pixel_values is not None:
|
||||
patch_height, patch_width = (
|
||||
int(pixel_values.shape[2] / self.config.patch_size),
|
||||
int(pixel_values.shape[3] / self.config.patch_size),
|
||||
torch_int(pixel_values.shape[2] / self.config.patch_size),
|
||||
torch_int(pixel_values.shape[3] / self.config.patch_size),
|
||||
)
|
||||
visual_embeddings = self.forward_image(pixel_values)
|
||||
visual_attention_mask = torch.ones(
|
||||
|
@ -39,6 +39,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_mobilevit import MobileViTConfig
|
||||
|
||||
@ -437,8 +438,16 @@ class MobileViTLayer(nn.Module):
|
||||
|
||||
batch_size, channels, orig_height, orig_width = features.shape
|
||||
|
||||
new_height = int(math.ceil(orig_height / patch_height) * patch_height)
|
||||
new_width = int(math.ceil(orig_width / patch_width) * patch_width)
|
||||
new_height = (
|
||||
torch_int(torch.ceil(orig_height / patch_height) * patch_height)
|
||||
if torch.jit.is_tracing()
|
||||
else int(math.ceil(orig_height / patch_height) * patch_height)
|
||||
)
|
||||
new_width = (
|
||||
torch_int(torch.ceil(orig_width / patch_width) * patch_width)
|
||||
if torch.jit.is_tracing()
|
||||
else int(math.ceil(orig_width / patch_width) * patch_width)
|
||||
)
|
||||
|
||||
interpolate = False
|
||||
if new_width != orig_width or new_height != orig_height:
|
||||
|
@ -15,7 +15,6 @@
|
||||
"""PyTorch SAM model."""
|
||||
|
||||
import collections
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -232,7 +231,7 @@ class SamAttention(nn.Module):
|
||||
# SamAttention
|
||||
_, _, _, c_per_head = query.shape
|
||||
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
|
||||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = attn / (c_per_head**0.5)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
if attention_similarity is not None:
|
||||
|
@ -36,6 +36,7 @@ from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_swin import SwinConfig
|
||||
@ -639,8 +640,10 @@ class SwinLayer(nn.Module):
|
||||
def set_shift_and_window_size(self, input_resolution):
|
||||
if min(input_resolution) <= self.window_size:
|
||||
# if window size is larger than input resolution, we don't partition windows
|
||||
self.shift_size = 0
|
||||
self.window_size = min(input_resolution)
|
||||
self.shift_size = torch_int(0)
|
||||
self.window_size = (
|
||||
torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
|
||||
)
|
||||
|
||||
def get_attn_mask(self, height, width, dtype, device):
|
||||
if self.shift_size > 0:
|
||||
|
@ -60,6 +60,8 @@ from .generic import (
|
||||
tensor_size,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
torch_float,
|
||||
torch_int,
|
||||
transpose,
|
||||
working_or_temp_dir,
|
||||
)
|
||||
|
@ -753,6 +753,30 @@ def infer_framework(model_class):
|
||||
raise TypeError(f"Could not infer framework from class {model_class}.")
|
||||
|
||||
|
||||
def torch_int(x):
|
||||
"""
|
||||
Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return int(x)
|
||||
|
||||
import torch
|
||||
|
||||
return x.to(torch.int64) if torch.jit.is_tracing() else int(x)
|
||||
|
||||
|
||||
def torch_float(x):
|
||||
"""
|
||||
Casts an input to a torch float32 tensor if we are in a tracing context, otherwise to a Python float.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return int(x)
|
||||
|
||||
import torch
|
||||
|
||||
return x.to(torch.float32) if torch.jit.is_tracing() else int(x)
|
||||
|
||||
|
||||
def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
||||
"""
|
||||
Decorator to filter out named arguments that are not in the function signature.
|
||||
|
Loading…
Reference in New Issue
Block a user