prepare for "__floordiv__ is deprecated and its behavior will change in a future version of pytorch" (#20211)

* use rounding_mode="floor" instead of // to prevent a behavioral change

* add other TODO

* use `torch_int_div` from pytorch_utils

* same for tests

* fix copies

* style

* use relative imports when needed

* Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Arthur authored on 2023-03-01 10:49:21 +01:00; committed by GitHub
parent b29e2dcaff
commit 44e3e3fb49
15 changed files with 53 additions and 51 deletions
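
The pattern throughout is the same: every tensor `//` (the deprecated `Tensor.__floordiv__`) is replaced either by `torch_int_div` from `pytorch_utils` or by `torch.div(..., rounding_mode="floor")` directly. A minimal sketch of the helper, assuming PyTorch >= 1.8 where `rounding_mode` is available (the real implementation may additionally guard older versions):

import torch

def torch_int_div(tensor1, tensor2):
    # torch.div with rounding_mode="floor" reproduces the old `//` semantics
    # without triggering the __floordiv__ deprecation warning
    return torch.div(tensor1, tensor2, rounding_mode="floor")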

View File

@@ -37,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_int_div
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -971,11 +971,8 @@ class BigBirdBlockSparseAttention(nn.Module):
         num_indices_to_gather = indices.shape[-2] * indices.shape[-1]
         num_indices_to_pick_from = params.shape[2]
 
-        indices_shift = (
-            torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
-            // num_indices_to_gather
-            * num_indices_to_pick_from
-        )
+        shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
+        indices_shift = torch_int_div(shift, num_indices_to_gather) * num_indices_to_pick_from
 
         flattened_indices = indices.view(-1) + indices_shift
         flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1])
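
The rewrite keeps the arithmetic identical: every flattened index is shifted by the offset of the (batch, head) block it belongs to. A quick worked example with hypothetical sizes (2 blocks, num_indices_to_gather = 3, num_indices_to_pick_from = 5):

shift = torch.arange(2 * 3)      # tensor([0, 1, 2, 3, 4, 5])
torch_int_div(shift, 3) * 5      # tensor([0, 0, 0, 5, 5, 5])
# same result as the removed `shift // 3 * 5`, minus the deprecation warning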

View File

@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,
@@ -789,11 +790,8 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
         num_indices_to_gather = indices.shape[-2] * indices.shape[-1]
         num_indices_to_pick_from = params.shape[2]
 
-        indices_shift = (
-            torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
-            // num_indices_to_gather
-            * num_indices_to_pick_from
-        )
+        shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
+        indices_shift = torch_int_div(shift, num_indices_to_gather) * num_indices_to_pick_from
 
         flattened_indices = indices.view(-1) + indices_shift
         flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1])

View File

@@ -67,6 +67,8 @@ if is_torch_available():
     import torch
     from torch import nn
 
+    from transformers.pytorch_utils import torch_int_div
+
 
 if is_vision_available():
     import PIL
@@ -1311,7 +1313,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
@@ -1357,7 +1359,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
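
Both hunks decode indices into a flattened (num_queries, num_labels) grid: floor division recovers the query (box) index, the modulo recovers the label. With a hypothetical num_labels = 4 (i.e. out_logits.shape[2] == 4):

topk_indexes = torch.tensor([[5, 11]])   # 5 = 1 * 4 + 1, 11 = 2 * 4 + 3
torch_int_div(topk_indexes, 4)           # tensor([[1, 2]]) -> box indices
topk_indexes % 4                         # tensor([[1, 3]]) -> label indices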

View File

@@ -504,7 +504,7 @@ def build_position_encoding(config):
 def gen_sine_position_embeddings(pos_tensor):
     scale = 2 * math.pi
     dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
-    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
+    dim_t = 10000 ** (2 * torch_int_div(dim_t, 2) / 128)
     x_embed = pos_tensor[:, :, 0] * scale
     y_embed = pos_tensor[:, :, 1] * scale
     pos_x = x_embed[:, :, None] / dim_t
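
Note that `dim_t` is float32 here, so the old `//` was float floor division. The floor matters: consecutive channels must share an exponent so that the sin and cos channels pair up, and plain true division would silently change the embeddings. That is why the sine-embedding hunks below keep rounding_mode="floor":

dim_t = torch.arange(6, dtype=torch.float32)
torch_int_div(dim_t, 2)    # tensor([0., 0., 1., 1., 2., 2.])  -> paired frequencies
torch.div(dim_t, 2)        # tensor([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])  -> true division, a behavioral change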

View File

@@ -67,6 +67,8 @@ if is_torch_available():
     import torch
     from torch import nn
 
+    from ...pytorch_utils import torch_int_div
+
 
 if is_vision_available():
     import PIL
@@ -1309,7 +1311,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
@@ -1354,7 +1356,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

View File

@@ -41,7 +41,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import meshgrid
+from ...pytorch_utils import meshgrid, torch_int_div
 from ...utils import is_ninja_available, logging
 from ..auto import AutoBackbone
 from .configuration_deformable_detr import DeformableDetrConfig
@@ -497,7 +497,7 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
         x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
 
         dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
-        dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)
+        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
 
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
@@ -1552,7 +1552,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         scale = 2 * math.pi
 
         dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
-        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
         proposals = proposals.sigmoid() * scale
         # batch_size, num_queries, 4, 128

View File

@@ -63,6 +63,8 @@ from ...utils.generic import ExplicitEnum, TensorType
 if is_torch_available():
     import torch
 
+    from ...pytorch_utils import torch_int_div
+
 
 if is_torchvision_available():
     from torchvision.ops.boxes import batched_nms
@@ -965,7 +967,7 @@ class DetaImageProcessor(BaseImageProcessor):
         all_scores = prob.view(batch_size, num_queries * num_labels).to(out_logits.device)
         all_indexes = torch.arange(num_queries * num_labels)[None].repeat(batch_size, 1).to(out_logits.device)
 
-        all_boxes = all_indexes // out_logits.shape[2]
+        all_boxes = torch_int_div(all_indexes, out_logits.shape[2])
         all_labels = all_indexes % out_logits.shape[2]
 
         boxes = center_to_corners_format(out_bbox)

View File

@@ -36,7 +36,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import meshgrid
+from ...pytorch_utils import meshgrid, torch_int_div
 from ...utils import is_torchvision_available, logging, requires_backends
 from ..auto import AutoBackbone
 from .configuration_deta import DetaConfig
@@ -399,7 +399,7 @@ class DetaSinePositionEmbedding(nn.Module):
         x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
 
         dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
-        dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)
+        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)
 
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
@@ -1463,7 +1463,7 @@ class DetaModel(DetaPreTrainedModel):
         scale = 2 * math.pi
 
         dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
-        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
         proposals = proposals.sigmoid() * scale
         # batch_size, num_queries, 4, 128

View File

@@ -57,6 +57,8 @@ if is_torch_available():
     import torch
     from torch import nn
 
+    from ...pytorch_utils import torch_int_div
+
 
 # Copied from transformers.models.detr.image_processing_detr.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:
@@ -1007,7 +1009,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
         scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
         labels_per_image = labels[topk_indices]
 
-        topk_indices = topk_indices // num_classes
+        topk_indices = torch_int_div(topk_indices, num_classes)
         mask_pred = mask_pred[topk_indices]
 
         pred_masks = (mask_pred > 0).float()

View File

@@ -61,6 +61,8 @@ if is_torch_available():
     import torch
     from torch import nn
 
+    from ...pytorch_utils import torch_int_div
+
 
 # Copied from transformers.models.detr.image_processing_detr.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:
@@ -1075,7 +1077,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
         scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
         labels_per_image = labels[topk_indices]
 
-        topk_indices = topk_indices // num_classes
+        topk_indices = torch_int_div(topk_indices, num_classes)
         mask_pred = mask_pred[topk_indices]
 
         pred_masks = (mask_pred > 0).float()

View File

@@ -58,6 +58,8 @@ if is_torch_available():
     import torch
     from torch import nn
 
+    from ...pytorch_utils import torch_int_div
+
 
 # Copied from transformers.models.detr.image_processing_detr.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:
@@ -1120,7 +1122,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
         scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
         labels_per_image = labels[topk_indices]
 
-        topk_indices = topk_indices // num_classes
+        topk_indices = torch_int_div(topk_indices, num_classes)
         # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
         mask_pred = masks_queries_logits[i][topk_indices]

View File

@@ -909,10 +909,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         relevant_bucket_idx_chunk = bucket_idx[tuple(relevant_bucket_idx_chunk.transpose(0, 1))]
 
         # adapt bucket_idx for batch and hidden states for index select
+        offset = torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long)
         bucket_idx_batch_offset = sequence_length * (
-            batch_size
-            * torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long)
-            // relevant_bucket_idx_chunk.shape[-1]
+            batch_size * torch.div(offset, relevant_bucket_idx_chunk.shape[-1], rounding_mode="floor")
         )
 
         # add batch offset

View File

@@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_int_div,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1632,11 +1637,8 @@ class ProductIndexMap(IndexMap):
 
     def project_outer(self, index):
         """Projects an index with the same index set onto the outer components."""
-        return IndexMap(
-            indices=(index.indices // self.inner_index.num_segments).type(torch.float).floor().type(torch.long),
-            num_segments=self.outer_index.num_segments,
-            batch_dims=index.batch_dims,
-        )
+        indices = torch_int_div(index.indices, self.inner_index.num_segments).type(torch.long)
+        return IndexMap(indices=indices, num_segments=self.outer_index.num_segments, batch_dims=index.batch_dims)
 
     def project_inner(self, index):
         """Projects an index with the same index set onto the inner components."""

View File

@@ -32,6 +32,7 @@ if is_torch_available():
         DisjunctiveConstraint,
         PhrasalConstraint,
     )
+    from transformers.pytorch_utils import torch_int_div
 
 
 class BeamSearchTester:
@@ -160,10 +161,8 @@ class BeamSearchTester:
         expected_output_scores = cut_expected_tensor(next_scores)
 
         # add num_beams * batch_idx
-        expected_output_indices = (
-            cut_expected_tensor(next_indices)
-            + (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
-        )
+        offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
+        expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
 
         self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
         self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
@@ -399,10 +398,8 @@ class ConstrainedBeamSearchTester:
         expected_output_scores = cut_expected_tensor(next_scores)
 
         # add num_beams * batch_idx
-        expected_output_indices = (
-            cut_expected_tensor(next_indices)
-            + (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
-        )
+        offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
+        expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
 
         self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
         self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
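
In both testers the offset maps each flattened beam slot to its batch entry before scaling by num_beams. With a hypothetical batch_size = 2 and num_beams = 3:

offset = torch_int_div(torch.arange(3 * 2), 3)   # tensor([0, 0, 0, 1, 1, 1])
offset * 3                                       # tensor([0, 0, 0, 3, 3, 3])  -> added to the per-batch beam indices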

View File

@@ -71,7 +71,7 @@ if is_torch_available():
         _compute_mask_indices,
         _sample_negative_indices,
     )
-    from transformers.pytorch_utils import is_torch_less_than_1_9
+    from transformers.pytorch_utils import is_torch_less_than_1_9, torch_int_div
 
 else:
     is_torch_less_than_1_9 = True
@@ -1217,10 +1217,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         sequence_length = 10
         hidden_size = 4
         num_negatives = 3
 
-        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
-            sequence_length, hidden_size
-        )  # each value in vector consists of same value
+        sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
+        features = sequence.view(sequence_length, hidden_size)  # each value in vector consists of same value
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
 
         # sample negative indices
@@ -1247,9 +1245,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
         mask[-1, sequence_length // 2 :] = 0
 
-        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
-            sequence_length, hidden_size
-        )  # each value in vector consists of same value
+        sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
+        features = sequence.view(sequence_length, hidden_size)  # each value in vector consists of same value
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
 
         # replace masked feature vectors with -100 to test that those are not sampled
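
The refactor preserves the fixture's intent: floor-dividing an arange by hidden_size produces row-constant feature vectors, which makes sampled negatives easy to identify. With a hypothetical sequence_length = 3 and hidden_size = 2:

sequence = torch_int_div(torch.arange(3 * 2), 2)   # tensor([0, 0, 1, 1, 2, 2])
sequence.view(3, 2)                                # tensor([[0, 0], [1, 1], [2, 2]])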