mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Optim deformable detr (#33600)
* optimize deformable detr * fix copies * remove deformable_detr_basline * fix hardcoded float16 and .float() * [run slow] deformable-detr,grounding-dino,mask2former,oneformer,rt-detr * [run slow] deformable_detr,grounding_dino,mask2former,oneformer,rt_detr
This commit is contained in:
parent
cac4a4876b
commit
ee71c9853a
@ -523,14 +523,14 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
|
||||
def forward(self, pixel_values, pixel_mask):
|
||||
if pixel_mask is None:
|
||||
raise ValueError("No pixel mask provided")
|
||||
y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
|
||||
y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
|
||||
x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
|
||||
dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
@ -580,11 +580,14 @@ def build_position_encoding(config):
|
||||
|
||||
|
||||
def multi_scale_deformable_attention(
|
||||
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
|
||||
value: Tensor,
|
||||
value_spatial_shapes: Union[Tensor, List[Tuple]],
|
||||
sampling_locations: Tensor,
|
||||
attention_weights: Tensor,
|
||||
) -> Tensor:
|
||||
batch_size, _, num_heads, hidden_dim = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
|
||||
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
||||
@ -672,6 +675,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@ -681,7 +685,8 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
|
||||
|
||||
batch_size, num_queries, _ = hidden_states.shape
|
||||
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
||||
total_elements = sum(height * width for height, width in spatial_shapes_list)
|
||||
if total_elements != sequence_length:
|
||||
raise ValueError(
|
||||
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
||||
)
|
||||
@ -716,9 +721,11 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
||||
|
||||
if self.disable_custom_kernels:
|
||||
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
|
||||
# PyTorch implementation
|
||||
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
|
||||
output = multi_scale_deformable_attention(
|
||||
value, spatial_shapes_list, sampling_locations, attention_weights
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# custom kernel
|
||||
@ -732,7 +739,9 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
|
||||
)
|
||||
except Exception:
|
||||
# PyTorch implementation
|
||||
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
|
||||
output = multi_scale_deformable_attention(
|
||||
value, spatial_shapes_list, sampling_locations, attention_weights
|
||||
)
|
||||
output = self.output_proj(output)
|
||||
|
||||
return output, attention_weights
|
||||
@ -877,6 +886,7 @@ class DeformableDetrEncoderLayer(nn.Module):
|
||||
position_embeddings: torch.Tensor = None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@ -909,6 +919,7 @@ class DeformableDetrEncoderLayer(nn.Module):
|
||||
position_embeddings=position_embeddings,
|
||||
reference_points=reference_points,
|
||||
spatial_shapes=spatial_shapes,
|
||||
spatial_shapes_list=spatial_shapes_list,
|
||||
level_start_index=level_start_index,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -974,6 +985,7 @@ class DeformableDetrDecoderLayer(nn.Module):
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
@ -1025,6 +1037,7 @@ class DeformableDetrDecoderLayer(nn.Module):
|
||||
position_embeddings=position_embeddings,
|
||||
reference_points=reference_points,
|
||||
spatial_shapes=spatial_shapes,
|
||||
spatial_shapes_list=spatial_shapes_list,
|
||||
level_start_index=level_start_index,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -1216,6 +1229,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
|
||||
attention_mask=None,
|
||||
position_embeddings=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
valid_ratios=None,
|
||||
output_attentions=None,
|
||||
@ -1257,7 +1271,8 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
|
||||
hidden_states = inputs_embeds
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
|
||||
spatial_shapes_tuple = tuple(spatial_shapes_list)
|
||||
reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
@ -1272,6 +1287,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
|
||||
position_embeddings,
|
||||
reference_points,
|
||||
spatial_shapes,
|
||||
spatial_shapes_list,
|
||||
level_start_index,
|
||||
output_attentions,
|
||||
)
|
||||
@ -1282,6 +1298,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
|
||||
position_embeddings=position_embeddings,
|
||||
reference_points=reference_points,
|
||||
spatial_shapes=spatial_shapes,
|
||||
spatial_shapes_list=spatial_shapes_list,
|
||||
level_start_index=level_start_index,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@ -1338,6 +1355,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
|
||||
position_embeddings=None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
valid_ratios=None,
|
||||
output_attentions=None,
|
||||
@ -1413,6 +1431,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
|
||||
position_embeddings,
|
||||
reference_points_input,
|
||||
spatial_shapes,
|
||||
spatial_shapes_list,
|
||||
level_start_index,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
@ -1425,6 +1444,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
reference_points=reference_points_input,
|
||||
spatial_shapes=spatial_shapes,
|
||||
spatial_shapes_list=spatial_shapes_list,
|
||||
level_start_index=level_start_index,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
@ -1586,7 +1606,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
temperature = 10000
|
||||
scale = 2 * math.pi
|
||||
|
||||
dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
|
||||
dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
|
||||
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
|
||||
# batch_size, num_queries, 4
|
||||
proposals = proposals.sigmoid() * scale
|
||||
@ -1717,7 +1737,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
source = self.input_proj[level](features[-1][0])
|
||||
else:
|
||||
source = self.input_proj[level](sources[-1])
|
||||
mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
|
||||
mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
|
||||
torch.bool
|
||||
)[0]
|
||||
pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
|
||||
sources.append(source)
|
||||
masks.append(mask)
|
||||
@ -1732,11 +1754,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
source_flatten = []
|
||||
mask_flatten = []
|
||||
lvl_pos_embed_flatten = []
|
||||
spatial_shapes = []
|
||||
spatial_shapes_list = []
|
||||
for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
|
||||
batch_size, num_channels, height, width = source.shape
|
||||
spatial_shape = (height, width)
|
||||
spatial_shapes.append(spatial_shape)
|
||||
spatial_shapes_list.append(spatial_shape)
|
||||
source = source.flatten(2).transpose(1, 2)
|
||||
mask = mask.flatten(1)
|
||||
pos_embed = pos_embed.flatten(2).transpose(1, 2)
|
||||
@ -1747,7 +1769,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
source_flatten = torch.cat(source_flatten, 1)
|
||||
mask_flatten = torch.cat(mask_flatten, 1)
|
||||
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
|
||||
|
||||
@ -1759,6 +1781,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
attention_mask=mask_flatten,
|
||||
position_embeddings=lvl_pos_embed_flatten,
|
||||
spatial_shapes=spatial_shapes,
|
||||
spatial_shapes_list=spatial_shapes_list,
|
||||
level_start_index=level_start_index,
|
||||
valid_ratios=valid_ratios,
|
||||
output_attentions=output_attentions,
|
||||
@ -1816,6 +1839,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
encoder_attention_mask=mask_flatten,
|
||||
reference_points=reference_points,
|
||||
spatial_shapes=spatial_shapes,
|
||||
spatial_shapes_list=spatial_shapes_list,
|
||||
level_start_index=level_start_index,
|
||||
valid_ratios=valid_ratios,
|
||||
output_attentions=output_attentions,
|
||||
|
@ -583,11 +583,14 @@ def build_position_encoding(config):
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
|
||||
def multi_scale_deformable_attention(
|
||||
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
|
||||
value: Tensor,
|
||||
value_spatial_shapes: Union[Tensor, List[Tuple]],
|
||||
sampling_locations: Tensor,
|
||||
attention_weights: Tensor,
|
||||
) -> Tensor:
|
||||
batch_size, _, num_heads, hidden_dim = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
|
||||
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
||||
@ -676,6 +679,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
|
||||
position_embeddings: Optional[torch.Tensor] = None,
|
||||
reference_points=None,
|
||||
spatial_shapes=None,
|
||||
spatial_shapes_list=None,
|
||||
level_start_index=None,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
@ -685,6 +689,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
|
||||
|
||||
batch_size, num_queries, _ = hidden_states.shape
|
||||
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||
# Ignore copy
|
||||
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
|
||||
raise ValueError(
|
||||
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
||||
@ -720,7 +725,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
||||
|
||||
if self.disable_custom_kernels:
|
||||
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
|
||||
# PyTorch implementation
|
||||
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
|
||||
else:
|
||||
|
@ -17,7 +17,7 @@
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -800,11 +800,14 @@ class Mask2FormerLoss(nn.Module):
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
|
||||
def multi_scale_deformable_attention(
|
||||
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
|
||||
value: Tensor,
|
||||
value_spatial_shapes: Union[Tensor, List[Tuple]],
|
||||
sampling_locations: Tensor,
|
||||
attention_weights: Tensor,
|
||||
) -> Tensor:
|
||||
batch_size, _, num_heads, hidden_dim = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
|
||||
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
||||
|
@ -18,7 +18,7 @@ import copy
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -63,11 +63,14 @@ def _get_clones(module, N):
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
|
||||
def multi_scale_deformable_attention(
|
||||
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
|
||||
value: Tensor,
|
||||
value_spatial_shapes: Union[Tensor, List[Tuple]],
|
||||
sampling_locations: Tensor,
|
||||
attention_weights: Tensor,
|
||||
) -> Tensor:
|
||||
batch_size, _, num_heads, hidden_dim = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
|
||||
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
||||
|
@ -733,13 +733,14 @@ class RTDetrCSPRepLayer(nn.Module):
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
|
||||
def multi_scale_deformable_attention(
|
||||
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
|
||||
value: Tensor,
|
||||
value_spatial_shapes: Union[Tensor, List[Tuple]],
|
||||
sampling_locations: Tensor,
|
||||
attention_weights: Tensor,
|
||||
) -> Tensor:
|
||||
batch_size, _, num_heads, hidden_dim = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
# Ignore copy
|
||||
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
|
||||
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level_id, (height, width) in enumerate(value_spatial_shapes):
|
||||
@ -838,9 +839,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
|
||||
|
||||
batch_size, num_queries, _ = hidden_states.shape
|
||||
batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||
|
||||
# Ignore copy
|
||||
total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
|
||||
total_elements = sum(height * width for height, width in spatial_shapes_list)
|
||||
if total_elements != sequence_length:
|
||||
raise ValueError(
|
||||
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
|
||||
@ -876,7 +875,6 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
|
||||
|
||||
# Ignore copy
|
||||
if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
|
||||
# PyTorch implementation
|
||||
output = multi_scale_deformable_attention(
|
||||
|
Loading…
Reference in New Issue
Block a user