diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 1084e7136a4..f380c3c3b48 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -523,14 +523,14 @@ class DeformableDetrSinePositionEmbedding(nn.Module): def forward(self, pixel_values, pixel_mask): if pixel_mask is None: raise ValueError("No pixel mask provided") - y_embed = pixel_mask.cumsum(1, dtype=torch.float32) - x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype) + x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype) if self.normalize: eps = 1e-6 y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() + dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) pos_x = x_embed[:, :, :, None] / dim_t @@ -580,11 +580,14 @@ def build_position_encoding(config): def multi_scale_deformable_attention( - value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor + value: Tensor, + value_spatial_shapes: Union[Tensor, List[Tuple]], + sampling_locations: Tensor, + attention_weights: Tensor, ) -> Tensor: batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes): @@ -672,6 +675,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): @@ -681,7 +685,8 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape - if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + total_elements = sum(height * width for height, width in spatial_shapes_list) + if total_elements != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" ) @@ -716,9 +721,11 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - if self.disable_custom_kernels: + if self.disable_custom_kernels or MultiScaleDeformableAttention is None: # PyTorch implementation - output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) + output = multi_scale_deformable_attention( + value, spatial_shapes_list, sampling_locations, attention_weights + ) else: try: # custom kernel @@ -732,7 +739,9 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module): ) except Exception: # PyTorch implementation - output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) + output = multi_scale_deformable_attention( + value, spatial_shapes_list, sampling_locations, attention_weights + ) output = self.output_proj(output) return output, attention_weights @@ -877,6 +886,7 @@ class DeformableDetrEncoderLayer(nn.Module): position_embeddings: torch.Tensor = None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): @@ -909,6 +919,7 @@ class DeformableDetrEncoderLayer(nn.Module): position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) @@ -974,6 +985,7 @@ class DeformableDetrDecoderLayer(nn.Module): position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, @@ -1025,6 +1037,7 @@ class DeformableDetrDecoderLayer(nn.Module): position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) @@ -1216,6 +1229,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel): attention_mask=None, position_embeddings=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, valid_ratios=None, output_attentions=None, @@ -1257,7 +1271,8 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel): hidden_states = inputs_embeds hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device) + spatial_shapes_tuple = tuple(spatial_shapes_list) + reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1272,6 +1287,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel): position_embeddings, reference_points, spatial_shapes, + spatial_shapes_list, level_start_index, output_attentions, ) @@ -1282,6 +1298,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel): position_embeddings=position_embeddings, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, output_attentions=output_attentions, ) @@ -1338,6 +1355,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): position_embeddings=None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, valid_ratios=None, output_attentions=None, @@ -1413,6 +1431,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): position_embeddings, reference_points_input, spatial_shapes, + spatial_shapes_list, level_start_index, encoder_hidden_states, encoder_attention_mask, @@ -1425,6 +1444,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): encoder_hidden_states=encoder_hidden_states, reference_points=reference_points_input, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, @@ -1586,7 +1606,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): temperature = 10000 scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float() + dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device) dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) # batch_size, num_queries, 4 proposals = proposals.sigmoid() * scale @@ -1717,7 +1737,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): source = self.input_proj[level](features[-1][0]) else: source = self.input_proj[level](sources[-1]) - mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0] + mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to( + torch.bool + )[0] pos_l = self.backbone.position_embedding(source, mask).to(source.dtype) sources.append(source) masks.append(mask) @@ -1732,11 +1754,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): source_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] - spatial_shapes = [] + spatial_shapes_list = [] for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)): batch_size, num_channels, height, width = source.shape spatial_shape = (height, width) - spatial_shapes.append(spatial_shape) + spatial_shapes_list.append(spatial_shape) source = source.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) @@ -1747,7 +1769,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): source_flatten = torch.cat(source_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) - spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device) + spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1) @@ -1759,6 +1781,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): attention_mask=mask_flatten, position_embeddings=lvl_pos_embed_flatten, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, output_attentions=output_attentions, @@ -1816,6 +1839,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): encoder_attention_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, level_start_index=level_start_index, valid_ratios=valid_ratios, output_attentions=output_attentions, diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 08e4b27af64..aaac7488f43 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -583,11 +583,14 @@ def build_position_encoding(config): # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention def multi_scale_deformable_attention( - value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor + value: Tensor, + value_spatial_shapes: Union[Tensor, List[Tuple]], + sampling_locations: Tensor, + attention_weights: Tensor, ) -> Tensor: batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes): @@ -676,6 +679,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module): position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes=None, + spatial_shapes_list=None, level_start_index=None, output_attentions: bool = False, ): @@ -685,6 +689,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module): batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape + # Ignore copy if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" @@ -720,7 +725,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - if self.disable_custom_kernels: + if self.disable_custom_kernels or MultiScaleDeformableAttention is None: # PyTorch implementation output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights) else: diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index c5788951fd5..6b94caf355d 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -800,11 +800,14 @@ class Mask2FormerLoss(nn.Module): # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention def multi_scale_deformable_attention( - value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor + value: Tensor, + value_spatial_shapes: Union[Tensor, List[Tuple]], + sampling_locations: Tensor, + attention_weights: Tensor, ) -> Tensor: batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes): diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 9c2f6622071..aeeccb68a92 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -18,7 +18,7 @@ import copy import math import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -63,11 +63,14 @@ def _get_clones(module, N): # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention def multi_scale_deformable_attention( - value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor + value: Tensor, + value_spatial_shapes: Union[Tensor, List[Tuple]], + sampling_locations: Tensor, + attention_weights: Tensor, ) -> Tensor: batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1) + value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes): diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 35af2ec8ecf..c4daba6d274 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -733,13 +733,14 @@ class RTDetrCSPRepLayer(nn.Module): # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention def multi_scale_deformable_attention( - value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor + value: Tensor, + value_spatial_shapes: Union[Tensor, List[Tuple]], + sampling_locations: Tensor, + attention_weights: Tensor, ) -> Tensor: batch_size, _, num_heads, hidden_dim = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape - # Ignore copy value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1) - sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level_id, (height, width) in enumerate(value_spatial_shapes): @@ -838,9 +839,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module): batch_size, num_queries, _ = hidden_states.shape batch_size, sequence_length, _ = encoder_hidden_states.shape - - # Ignore copy - total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list) + total_elements = sum(height * width for height, width in spatial_shapes_list) if total_elements != sequence_length: raise ValueError( "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" @@ -876,7 +875,6 @@ class RTDetrMultiscaleDeformableAttention(nn.Module): else: raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}") - # Ignore copy if self.disable_custom_kernels or MultiScaleDeformableAttention is None: # PyTorch implementation output = multi_scale_deformable_attention(