Mirror of https://github.com/huggingface/transformers.git
Fix torch meshgrid warnings (#20475)
* fix torch meshgrid warnings
* support lower torch versions
* don't edit examples
* dont edit examples
* fix ci
* fix style
* rebase cleanup
* fix ci again
commit 3b91f96fc9 (parent ae1cffaf3c)
src/transformers/models/beit/modeling_beit.py
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
     SemanticSegmenterOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -446,7 +446,7 @@ class BeitRelativePositionBias(nn.Module):
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(window_size[0])
         coords_w = torch.arange(window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
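Note (not part of the diff): the change is behavior-preserving because indexing="ij" is exactly the implicit default of torch.meshgrid before 1.10, so the relative-position index table comes out identical, just without the UserWarning. A quick standalone check, with a made-up window_size:

import torch

window_size = (4, 4)  # hypothetical (Wh, Ww) for illustration
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
# explicit indexing="ij" matches the old implicit default, so no behavior change
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
print(relative_coords.shape)  # torch.Size([2, 16, 16])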
src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -33,7 +33,7 @@ from ...modeling_outputs import (
     SemanticSegmenterOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -457,7 +457,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(window_size[0])
         coords_w = torch.arange(window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -41,6 +41,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import meshgrid
 from ...utils import is_ninja_available, logging
 from .configuration_deformable_detr import DeformableDetrConfig
 from .load_custom import load_cuda_kernels
@@ -1144,9 +1145,10 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         reference_points_list = []
         for level, (height, width) in enumerate(spatial_shapes):
 
-            ref_y, ref_x = torch.meshgrid(
+            ref_y, ref_x = meshgrid(
                 torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
                 torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
+                indexing="ij",
             )
             # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
             ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
@@ -1555,9 +1557,10 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
             valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
 
-            grid_y, grid_x = torch.meshgrid(
+            grid_y, grid_x = meshgrid(
                 torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
                 torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
+                indexing="ij",
            )
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
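Note (not part of the diff): a standalone sketch of what the encoder's reference-point grid computes for a single feature level. The shapes are made-up examples and valid_ratios is taken as 1.0 (an unpadded image):

import torch

height, width = 4, 6  # hypothetical feature-map size for one level
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, height - 0.5, height, dtype=torch.float32),
    torch.linspace(0.5, width - 0.5, width, dtype=torch.float32),
    indexing="ij",  # the wrapper forwards this on torch>=1.10
)
# normalize pixel-center coordinates into (0, 1); valid_ratios == 1.0 here
reference_points = torch.stack((ref_x.reshape(-1) / width, ref_y.reshape(-1) / height), -1)
print(reference_points.shape)  # torch.Size([24, 2]): one (x, y) point per feature location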
src/transformers/models/donut/modeling_donut_swin.py
@@ -28,7 +28,7 @@ from torch import nn
 
 from ...activations import ACT2FN
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -355,7 +355,7 @@ class DonutSwinSelfAttention(nn.Module):
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
         coords_flatten = torch.flatten(coords, 1)
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()
src/transformers/models/maskformer/modeling_maskformer_swin.py
@@ -28,7 +28,7 @@ from ...activations import ACT2FN
 from ...file_utils import ModelOutput
 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from .configuration_maskformer_swin import MaskFormerSwinConfig
 
 
@@ -315,7 +315,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
         coords_flatten = torch.flatten(coords, 1)
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()
src/transformers/models/perceiver/modeling_perceiver.py
@@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -2704,7 +2704,7 @@ def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
         return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)
 
     dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
-    array_index_grid = torch.meshgrid(*dim_ranges)
+    array_index_grid = meshgrid(*dim_ranges, indexing="ij")
 
     return torch.stack(array_index_grid, dim=-1)
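Note (not part of the diff): build_linear_positions stacks the meshgrid output into one coordinate per grid cell. A minimal sketch with made-up index_dims:

import torch

index_dims = (2, 3)  # hypothetical example
dim_ranges = [torch.linspace(-1.0, 1.0, steps=n, dtype=torch.float32) for n in index_dims]
array_index_grid = torch.meshgrid(*dim_ranges, indexing="ij")  # what the wrapper forwards to
positions = torch.stack(array_index_grid, dim=-1)
print(positions.shape)  # torch.Size([2, 3, 2]): one (row, col) coordinate in [-1, 1] per cell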
src/transformers/models/swin/modeling_swin.py
@@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -426,7 +426,7 @@ class SwinSelfAttention(nn.Module):
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
         coords_flatten = torch.flatten(coords, 1)
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()
src/transformers/models/swinv2/modeling_swinv2.py
@@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -439,7 +439,7 @@ class Swinv2SelfAttention(nn.Module):
         relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
         relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
         relative_coords_table = (
-            torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
+            torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
             .permute(1, 2, 0)
             .contiguous()
             .unsqueeze(0)
@@ -459,7 +459,7 @@ class Swinv2SelfAttention(nn.Module):
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
         coords_flatten = torch.flatten(coords, 1)
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()
src/transformers/models/vilt/modeling_vilt.py
@@ -34,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_or_equal_than_1_10, prune_linear_layer
+from ...pytorch_utils import (
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_10,
+    meshgrid,
+    prune_linear_layer,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vilt import ViltConfig
 
@@ -138,7 +143,7 @@ class ViltEmbeddings(nn.Module):
         x = x.flatten(2).transpose(1, 2)
         # Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
         patch_index = torch.stack(
-            torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
+            meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
         ).to(device=x_mask.device)
         patch_index = patch_index[None, None, :, :, :]
         patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
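Note (not part of the diff): patch_index pairs a (row, column) index with every patch position in the mask. A standalone sketch with a made-up mask shape:

import torch

x_mask = torch.ones(2, 1, 12, 12)  # hypothetical (batch, 1, H_patches, W_patches)
patch_index = torch.stack(
    torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
).to(device=x_mask.device)
print(patch_index.shape)  # torch.Size([12, 12, 2]): one (row, col) pair per patch position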
src/transformers/pytorch_utils.py
@@ -270,3 +270,19 @@ def find_pruneable_heads_and_indices(
     mask = mask.view(-1).contiguous().eq(1)
     index: torch.LongTensor = torch.arange(len(mask))[mask].long()
     return heads, index
+
+
+def meshgrid(
+    *tensors: Union[torch.Tensor, List[torch.Tensor]], indexing: Optional[str] = None
+) -> Tuple[torch.Tensor, ...]:
+    """
+    Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.
+
+    Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
+    """
+    if is_torch_greater_or_equal_than_1_10:
+        return torch.meshgrid(*tensors, indexing=indexing)
+    else:
+        if indexing != "ij":
+            raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
+        return torch.meshgrid(*tensors)
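Note (not part of the diff): callers import the wrapper from transformers.pytorch_utils and use it exactly like torch.meshgrid. On torch>=1.10 it forwards indexing (silencing the UserWarning); on older versions it falls back to the implicit "ij" behavior and rejects any other value. A quick usage sketch:

import torch

from transformers.pytorch_utils import meshgrid

ys, xs = meshgrid(torch.arange(3), torch.arange(5), indexing="ij")
print(ys.shape, xs.shape)  # torch.Size([3, 5]) torch.Size([3, 5])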