mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
prepare for "__floordiv__ is deprecated and its behavior will change in a future version of pytorch" (#20211)
* rounding_mode="floor" instead of // to prevent behavioral change
* add other TODO
* use `torch_int_div` from pytorch_utils
* same for tests
* fix copies
* style
* use relative imports when needed

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
parent b29e2dcaff
commit 44e3e3fb49
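
Background: on recent PyTorch releases, `tensor // other` (`__floordiv__`) emits a deprecation warning because its rounding behavior is slated to change. This commit routes tensor floor division through the `torch_int_div` helper in `src/transformers/pytorch_utils.py` instead. A minimal sketch of what that helper boils down to (modulo version guards in the real file):

import torch

def torch_int_div(tensor1, tensor2):
    # Same result as `tensor1 // tensor2`, but pinned to floor
    # rounding so the deprecation warning never fires.
    return torch.div(tensor1, tensor2, rounding_mode="floor")
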
@@ -37,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, torch_int_div
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -971,11 +971,8 @@ class BigBirdBlockSparseAttention(nn.Module):
         num_indices_to_gather = indices.shape[-2] * indices.shape[-1]
         num_indices_to_pick_from = params.shape[2]

-        indices_shift = (
-            torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
-            // num_indices_to_gather
-            * num_indices_to_pick_from
-        )
+        shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
+        indices_shift = torch_int_div(shift, num_indices_to_gather) * num_indices_to_pick_from

         flattened_indices = indices.view(-1) + indices_shift
         flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1])

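The rewrite above is purely mechanical: pulling the `torch.arange` out into `shift` and replacing the chained `// ... * ...` with `torch_int_div` yields the same tensor. A quick standalone check, with made-up sizes standing in for the real block-sparse shapes:

import torch
from transformers.pytorch_utils import torch_int_div

num_indices_to_gather = 24    # hypothetical indices.shape[-2] * indices.shape[-1]
num_indices_to_pick_from = 7  # hypothetical params.shape[2]
total = 2 * 3 * num_indices_to_gather  # hypothetical batch * heads * gather count

old = torch.arange(total) // num_indices_to_gather * num_indices_to_pick_from
shift = torch.arange(total)
new = torch_int_div(shift, num_indices_to_gather) * num_indices_to_pick_from
assert torch.equal(old, new)
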
@@ -36,6 +36,7 @@ from ...modeling_outputs import (
     Seq2SeqSequenceClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_int_div
 from ...utils import (
     add_code_sample_docstrings,
     add_end_docstrings,

@@ -789,11 +790,8 @@ class BigBirdPegasusBlockSparseAttention(nn.Module):
         num_indices_to_gather = indices.shape[-2] * indices.shape[-1]
         num_indices_to_pick_from = params.shape[2]

-        indices_shift = (
-            torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
-            // num_indices_to_gather
-            * num_indices_to_pick_from
-        )
+        shift = torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device)
+        indices_shift = torch_int_div(shift, num_indices_to_gather) * num_indices_to_pick_from

         flattened_indices = indices.view(-1) + indices_shift
         flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1])

@@ -67,6 +67,8 @@ if is_torch_available():
     import torch
     from torch import nn

+    from transformers.pytorch_utils import torch_int_div
+

 if is_vision_available():
     import PIL

@@ -1311,7 +1313,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

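The `torch.topk` above runs over a flattened (num_queries x num_classes) axis, so each flat index has to be split back into a box index (integer division) and a class label (modulo); only the division needed the `torch_int_div` treatment. A toy decode with hypothetical sizes:

import torch
from transformers.pytorch_utils import torch_int_div

num_classes = 5                  # hypothetical out_logits.shape[2]
flat = torch.tensor([0, 7, 19])  # hypothetical flattened top-k indices
topk_boxes = torch_int_div(flat, num_classes)  # tensor([0, 1, 3])
labels = flat % num_classes                    # tensor([0, 2, 4])
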
@@ -1357,7 +1359,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

@@ -504,7 +504,7 @@ def build_position_encoding(config):
 def gen_sine_position_embeddings(pos_tensor):
     scale = 2 * math.pi
     dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
-    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
+    dim_t = 10000 ** (2 * torch_int_div(dim_t, 2) / 128)
     x_embed = pos_tensor[:, :, 0] * scale
     y_embed = pos_tensor[:, :, 1] * scale
     pos_x = x_embed[:, :, None] / dim_t

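`dim_t` is a float tensor here, so the old `dim_t // 2` was float floor division, exactly the case the deprecation warning targets. `torch_int_div(dim_t, 2)` floors the same way (and keeps the float dtype), so the sine/cosine frequencies are unchanged. A quick equivalence check:

import torch
from transformers.pytorch_utils import torch_int_div

dim_t = torch.arange(128, dtype=torch.float32)
old = 10000 ** (2 * (dim_t // 2) / 128)             # warns on newer PyTorch
new = 10000 ** (2 * torch_int_div(dim_t, 2) / 128)  # same values, no warning
assert torch.equal(old, new)
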
@@ -67,6 +67,8 @@ if is_torch_available():
     import torch
     from torch import nn

+    from ...pytorch_utils import torch_int_div
+

 if is_vision_available():
     import PIL

@@ -1309,7 +1311,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

@@ -1354,7 +1356,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
         prob = out_logits.sigmoid()
         topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
         scores = topk_values
-        topk_boxes = topk_indexes // out_logits.shape[2]
+        topk_boxes = torch_int_div(topk_indexes, out_logits.shape[2])
         labels = topk_indexes % out_logits.shape[2]
         boxes = center_to_corners_format(out_bbox)
         boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

@@ -41,7 +41,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import meshgrid
+from ...pytorch_utils import meshgrid, torch_int_div
 from ...utils import is_ninja_available, logging
 from ..auto import AutoBackbone
 from .configuration_deformable_detr import DeformableDetrConfig

@@ -497,7 +497,7 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
         x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

         dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
-        dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)
+        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)

         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t

@@ -1552,7 +1552,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         scale = 2 * math.pi

         dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
-        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
         proposals = proposals.sigmoid() * scale
         # batch_size, num_queries, 4, 128

@@ -63,6 +63,8 @@ from ...utils.generic import ExplicitEnum, TensorType

 if is_torch_available():
     import torch

+    from ...pytorch_utils import torch_int_div
+
 if is_torchvision_available():
     from torchvision.ops.boxes import batched_nms

@@ -965,7 +967,7 @@ class DetaImageProcessor(BaseImageProcessor):

         all_scores = prob.view(batch_size, num_queries * num_labels).to(out_logits.device)
         all_indexes = torch.arange(num_queries * num_labels)[None].repeat(batch_size, 1).to(out_logits.device)
-        all_boxes = all_indexes // out_logits.shape[2]
+        all_boxes = torch_int_div(all_indexes, out_logits.shape[2])
         all_labels = all_indexes % out_logits.shape[2]

         boxes = center_to_corners_format(out_bbox)

@@ -36,7 +36,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import meshgrid
+from ...pytorch_utils import meshgrid, torch_int_div
 from ...utils import is_torchvision_available, logging, requires_backends
 from ..auto import AutoBackbone
 from .configuration_deta import DetaConfig

@@ -399,7 +399,7 @@ class DetaSinePositionEmbedding(nn.Module):
         x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

         dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
-        dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim)
+        dim_t = self.temperature ** (2 * torch_int_div(dim_t, 2) / self.embedding_dim)

         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t

@@ -1463,7 +1463,7 @@ class DetaModel(DetaPreTrainedModel):
         scale = 2 * math.pi

         dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
-        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
         proposals = proposals.sigmoid() * scale
         # batch_size, num_queries, 4, 128

@@ -57,6 +57,8 @@ if is_torch_available():
     import torch
     from torch import nn

+    from ...pytorch_utils import torch_int_div
+

 # Copied from transformers.models.detr.image_processing_detr.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:

@@ -1007,7 +1009,7 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
             scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
             labels_per_image = labels[topk_indices]

-            topk_indices = topk_indices // num_classes
+            topk_indices = torch_int_div(topk_indices, num_classes)
             mask_pred = mask_pred[topk_indices]
             pred_masks = (mask_pred > 0).float()

@@ -61,6 +61,8 @@ if is_torch_available():
     import torch
     from torch import nn

+    from ...pytorch_utils import torch_int_div
+

 # Copied from transformers.models.detr.image_processing_detr.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:

@@ -1075,7 +1077,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
             scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
             labels_per_image = labels[topk_indices]

-            topk_indices = topk_indices // num_classes
+            topk_indices = torch_int_div(topk_indices, num_classes)
             mask_pred = mask_pred[topk_indices]
             pred_masks = (mask_pred > 0).float()

@@ -58,6 +58,8 @@ if is_torch_available():
     import torch
     from torch import nn

+    from ...pytorch_utils import torch_int_div
+

 # Copied from transformers.models.detr.image_processing_detr.max_across_indices
 def max_across_indices(values: Iterable[Any]) -> List[Any]:

@@ -1120,7 +1122,7 @@ class OneFormerImageProcessor(BaseImageProcessor):
             scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
             labels_per_image = labels[topk_indices]

-            topk_indices = topk_indices // num_classes
+            topk_indices = torch_int_div(topk_indices, num_classes)
             # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
             mask_pred = masks_queries_logits[i][topk_indices]

@@ -909,10 +909,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         relevant_bucket_idx_chunk = bucket_idx[tuple(relevant_bucket_idx_chunk.transpose(0, 1))]

         # adapt bucket_idx for batch and hidden states for index select
+        offset = torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long)
         bucket_idx_batch_offset = sequence_length * (
-            batch_size
-            * torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long)
-            // relevant_bucket_idx_chunk.shape[-1]
+            batch_size * torch.div(offset, relevant_bucket_idx_chunk.shape[-1], rounding_mode="floor")
         )

         # add batch offset

@@ -29,7 +29,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+    torch_int_div,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,

@@ -1632,11 +1637,8 @@ class ProductIndexMap(IndexMap):

     def project_outer(self, index):
         """Projects an index with the same index set onto the outer components."""
-        return IndexMap(
-            indices=(index.indices // self.inner_index.num_segments).type(torch.float).floor().type(torch.long),
-            num_segments=self.outer_index.num_segments,
-            batch_dims=index.batch_dims,
-        )
+        indices = torch_int_div(index.indices, self.inner_index.num_segments).type(torch.long)
+        return IndexMap(indices=indices, num_segments=self.outer_index.num_segments, batch_dims=index.batch_dims)

     def project_inner(self, index):
         """Projects an index with the same index set onto the inner components."""

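The old TAPAS code floored by round-tripping through float (`.type(torch.float).floor().type(torch.long)`); the replacement floors directly with integer division, which agrees for the non-negative segment indices involved. A toy check with hypothetical values:

import torch
from transformers.pytorch_utils import torch_int_div

indices = torch.tensor([0, 1, 5, 9, 14])  # hypothetical flat indices
num_segments = 3                          # hypothetical inner index size
old = (indices // num_segments).type(torch.float).floor().type(torch.long)
new = torch_int_div(indices, num_segments).type(torch.long)
assert torch.equal(old, new)  # both: tensor([0, 0, 1, 3, 4])
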
@@ -32,6 +32,7 @@ if is_torch_available():
         DisjunctiveConstraint,
         PhrasalConstraint,
     )
+    from transformers.pytorch_utils import torch_int_div


 class BeamSearchTester:

@@ -160,10 +161,8 @@ class BeamSearchTester:
         expected_output_scores = cut_expected_tensor(next_scores)

         # add num_beams * batch_idx
-        expected_output_indices = (
-            cut_expected_tensor(next_indices)
-            + (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
-        )
+        offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
+        expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams

         self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
         self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())

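The extracted `offset` tensor is just the batch index of each beam slot, so `offset * num_beams` is the familiar per-batch shift added to the beam indices. For example, with num_beams=2 and batch_size=3:

import torch
from transformers.pytorch_utils import torch_int_div

num_beams, batch_size = 2, 3
offset = torch_int_div(torch.arange(num_beams * batch_size), num_beams)
print(offset)              # tensor([0, 0, 1, 1, 2, 2])
print(offset * num_beams)  # tensor([0, 0, 2, 2, 4, 4])
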
@@ -399,10 +398,8 @@ class ConstrainedBeamSearchTester:
         expected_output_scores = cut_expected_tensor(next_scores)

         # add num_beams * batch_idx
-        expected_output_indices = (
-            cut_expected_tensor(next_indices)
-            + (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
-        )
+        offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
+        expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams

         self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
         self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())

@@ -71,7 +71,7 @@ if is_torch_available():
         _compute_mask_indices,
         _sample_negative_indices,
     )
-    from transformers.pytorch_utils import is_torch_less_than_1_9
+    from transformers.pytorch_utils import is_torch_less_than_1_9, torch_int_div
 else:
     is_torch_less_than_1_9 = True

@@ -1217,10 +1217,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         sequence_length = 10
         hidden_size = 4
         num_negatives = 3

-        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
-            sequence_length, hidden_size
-        )  # each value in vector consits of same value
+        sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
+        features = sequence.view(sequence_length, hidden_size)  # each value in vector consits of same value
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

         # sample negative indices

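The reworked test setup still builds a (sequence_length, hidden_size) matrix whose i-th row is filled with the value i, which is what lets the test recognize which vectors were sampled as negatives. With toy sizes:

import torch
from transformers.pytorch_utils import torch_int_div

sequence_length, hidden_size = 3, 2
sequence = torch_int_div(torch.arange(sequence_length * hidden_size), hidden_size)
features = sequence.view(sequence_length, hidden_size)
# tensor([[0, 0],
#         [1, 1],
#         [2, 2]])
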
@@ -1247,9 +1245,8 @@
         mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
         mask[-1, sequence_length // 2 :] = 0

-        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
-            sequence_length, hidden_size
-        )  # each value in vector consits of same value
+        sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
+        features = sequence.view(sequence_length, hidden_size)  # each value in vector consits of same value
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

         # replace masked feature vectors with -100 to test that those are not sampled