Fix torch meshgrid warnings (#20475)

* fix torch meshgrid warnings

* support lower torch versions

* don't edit examples

* dont edit examples

* fix ci

* fix style

* rebase cleanup

* fix ci again
This commit is contained in:
fxmarty 2022-11-29 14:38:23 +01:00 committed by GitHub
parent ae1cffaf3c
commit 3b91f96fc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 43 additions and 19 deletions

View File

@ -34,7 +34,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -446,7 +446,7 @@ class BeitRelativePositionBias(nn.Module):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

View File

@ -33,7 +33,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -457,7 +457,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

View File

@ -41,6 +41,7 @@ from ...file_utils import (
)
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels
@ -1144,9 +1145,10 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
reference_points_list = []
for level, (height, width) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
ref_y, ref_x = meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
indexing="ij",
)
# TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
@ -1555,9 +1557,10 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(
grid_y, grid_x = meshgrid(
torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
indexing="ij",
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

View File

@ -28,7 +28,7 @@ from torch import nn
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@ -355,7 +355,7 @@ class DonutSwinSelfAttention(nn.Module):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()

View File

@ -28,7 +28,7 @@ from ...activations import ACT2FN
from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from .configuration_maskformer_swin import MaskFormerSwinConfig
@ -315,7 +315,7 @@ class MaskFormerSwinSelfAttention(nn.Module):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()

View File

@ -30,7 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
add_start_docstrings,
@ -2704,7 +2704,7 @@ def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
return torch.linspace(start=output_range[0], end=output_range[1], steps=n_xels_per_dim, dtype=torch.float32)
dim_ranges = [_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
array_index_grid = torch.meshgrid(*dim_ranges)
array_index_grid = meshgrid(*dim_ranges, indexing="ij")
return torch.stack(array_index_grid, dim=-1)

View File

@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@ -426,7 +426,7 @@ class SwinSelfAttention(nn.Module):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()

View File

@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
@ -439,7 +439,7 @@ class Swinv2SelfAttention(nn.Module):
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = (
torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
.permute(1, 2, 0)
.contiguous()
.unsqueeze(0)
@ -459,7 +459,7 @@ class Swinv2SelfAttention(nn.Module):
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()

View File

@ -34,7 +34,12 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_or_equal_than_1_10, prune_linear_layer
from ...pytorch_utils import (
find_pruneable_heads_and_indices,
is_torch_greater_or_equal_than_1_10,
meshgrid,
prune_linear_layer,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_vilt import ViltConfig
@ -138,7 +143,7 @@ class ViltEmbeddings(nn.Module):
x = x.flatten(2).transpose(1, 2)
# Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
patch_index = torch.stack(
torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
).to(device=x_mask.device)
patch_index = patch_index[None, None, :, :, :]
patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)

View File

@ -270,3 +270,19 @@ def find_pruneable_heads_and_indices(
mask = mask.view(-1).contiguous().eq(1)
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
return heads, index
def meshgrid(
*tensors: Union[torch.Tensor, List[torch.Tensor]], indexing: Optional[str] = None
) -> Tuple[torch.Tensor, ...]:
"""
Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.
Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
"""
if is_torch_greater_or_equal_than_1_10:
return torch.meshgrid(*tensors, indexing=indexing)
else:
if indexing != "ij":
raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
return torch.meshgrid(*tensors)