Optim deformable detr (#33600)

* optimize deformable detr

* fix copies

* remove deformable_detr_baseline

* fix hardcoded float16 and .float() (dtype pattern sketched below the change summary)

* [run slow] deformable-detr,grounding-dino,mask2former,oneformer,rt-detr

* [run slow] deformable_detr,grounding_dino,mask2former,oneformer,rt_detr
Yoni Gozlan 2024-10-02 15:46:27 +02:00 committed by GitHub
parent cac4a4876b
commit ee71c9853a
5 changed files with 64 additions and 31 deletions
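
The dtype fixes in the diff below all follow one pattern: derive constant tensors' dtype from the incoming activations instead of hardcoding torch.float32 or round-tripping through .float(), so a model loaded in float16/bfloat16 stays in that precision end to end. A minimal sketch of the pattern (illustrative toy module, not code from this commit):

import torch
from torch import nn


class SinePositionSketch(nn.Module):
    # Hypothetical toy module illustrating the pattern; not the library's class.
    def __init__(self, embedding_dim=64, temperature=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature

    def forward(self, pixel_values, pixel_mask):
        # Before: pixel_mask.cumsum(1, dtype=torch.float32) upcast under half precision.
        y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
        # Before: torch.arange(..., dtype=torch.int64, ...).float() forced fp32.
        dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
        return y_embed[:, :, :, None] / dim_t


pixel_values = torch.randn(1, 3, 4, 4, dtype=torch.bfloat16)
pixel_mask = torch.ones(1, 4, 4, dtype=torch.bool)
pos = SinePositionSketch()(pixel_values, pixel_mask)
assert pos.dtype == torch.bfloat16  # no hidden fp32 upcast anywhere in the path
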

src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -523,14 +523,14 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
     def forward(self, pixel_values, pixel_mask):
         if pixel_mask is None:
             raise ValueError("No pixel mask provided")
-        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
-        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
+        y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
+        x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
         if self.normalize:
             eps = 1e-6
             y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
-        dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
+        dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
         pos_x = x_embed[:, :, :, None] / dim_t
@@ -580,11 +580,14 @@ def build_position_encoding(config):
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -672,6 +675,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -681,7 +685,8 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+        total_elements = sum(height * width for height, width in spatial_shapes_list)
+        if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
             )
@@ -716,9 +721,11 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-        if self.disable_custom_kernels:
+        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+            output = multi_scale_deformable_attention(
+                value, spatial_shapes_list, sampling_locations, attention_weights
+            )
         else:
             try:
                 # custom kernel
@@ -732,7 +739,9 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
                 )
             except Exception:
                 # PyTorch implementation
-                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+                output = multi_scale_deformable_attention(
+                    value, spatial_shapes_list, sampling_locations, attention_weights
+                )
         output = self.output_proj(output)
         return output, attention_weights
@@ -877,6 +886,7 @@ class DeformableDetrEncoderLayer(nn.Module):
         position_embeddings: torch.Tensor = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -909,6 +919,7 @@ class DeformableDetrEncoderLayer(nn.Module):
             position_embeddings=position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
         )
@@ -974,6 +985,7 @@ class DeformableDetrDecoderLayer(nn.Module):
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -1025,6 +1037,7 @@ class DeformableDetrDecoderLayer(nn.Module):
             position_embeddings=position_embeddings,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
         )
@@ -1216,6 +1229,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         attention_mask=None,
         position_embeddings=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1257,7 +1271,8 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
+        spatial_shapes_tuple = tuple(spatial_shapes_list)
+        reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1272,6 +1287,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
                     position_embeddings,
                     reference_points,
                     spatial_shapes,
+                    spatial_shapes_list,
                     level_start_index,
                     output_attentions,
                 )
@@ -1282,6 +1298,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
                     position_embeddings=position_embeddings,
                     reference_points=reference_points,
                     spatial_shapes=spatial_shapes,
+                    spatial_shapes_list=spatial_shapes_list,
                     level_start_index=level_start_index,
                     output_attentions=output_attentions,
                 )
@@ -1338,6 +1355,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
         position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1413,6 +1431,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                     position_embeddings,
                     reference_points_input,
                     spatial_shapes,
+                    spatial_shapes_list,
                     level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
@@ -1425,6 +1444,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                     encoder_hidden_states=encoder_hidden_states,
                     reference_points=reference_points_input,
                     spatial_shapes=spatial_shapes,
+                    spatial_shapes_list=spatial_shapes_list,
                     level_start_index=level_start_index,
                     encoder_attention_mask=encoder_attention_mask,
                     output_attentions=output_attentions,
@@ -1586,7 +1606,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         temperature = 10000
         scale = 2 * math.pi
-        dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
+        dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
         dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
         proposals = proposals.sigmoid() * scale
@@ -1717,7 +1737,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                     source = self.input_proj[level](features[-1][0])
                 else:
                     source = self.input_proj[level](sources[-1])
-                mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
+                mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
+                    torch.bool
+                )[0]
                 pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
                 sources.append(source)
                 masks.append(mask)
@@ -1732,11 +1754,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         source_flatten = []
         mask_flatten = []
         lvl_pos_embed_flatten = []
-        spatial_shapes = []
+        spatial_shapes_list = []
         for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
             batch_size, num_channels, height, width = source.shape
             spatial_shape = (height, width)
-            spatial_shapes.append(spatial_shape)
+            spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
@@ -1747,7 +1769,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         source_flatten = torch.cat(source_flatten, 1)
         mask_flatten = torch.cat(mask_flatten, 1)
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
         valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
@@ -1759,6 +1781,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                 attention_mask=mask_flatten,
                 position_embeddings=lvl_pos_embed_flatten,
                 spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
                 output_attentions=output_attentions,
@@ -1816,6 +1839,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
             encoder_attention_mask=mask_flatten,
             reference_points=reference_points,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios,
             output_attentions=output_attentions,
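
The spatial_shapes_list threading above exists because height.item()/width.item() on CUDA tensors force one GPU-to-host sync per feature level and break torch.compile graph capture; keeping the per-level (height, width) sizes as plain Python ints makes the split sizes free. A rough sketch with illustrative values (not this file's code):

from typing import List, Tuple

import torch

# Per-level feature map sizes as host-side ints (illustrative values).
spatial_shapes_list: List[Tuple[int, int]] = [(8, 8), (4, 4), (2, 2)]
sequence_length = sum(h * w for h, w in spatial_shapes_list)  # 64 + 16 + 4 = 84
value = torch.randn(2, sequence_length, 256)

# Old path: sizes pulled out of a tensor, one .item() round-trip per level.
spatial_shapes = torch.tensor(spatial_shapes_list)
sizes_with_sync = [h.item() * w.item() for h, w in spatial_shapes]

# New path: the same sizes straight from the list, no device round-trip.
sizes_no_sync = [h * w for h, w in spatial_shapes_list]
assert sizes_with_sync == sizes_no_sync == [64, 16, 4]

value_list = value.split(sizes_no_sync, dim=1)
assert [v.shape[1] for v in value_list] == [64, 16, 4]
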

src/transformers/models/grounding_dino/modeling_grounding_dino.py

@@ -583,11 +583,14 @@ def build_position_encoding(config):
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -676,6 +679,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -685,6 +689,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
+        # Ignore copy
         if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -720,7 +725,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-        if self.disable_custom_kernels:
+        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
             output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
         else:
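
The guard rewritten above is a graceful-degradation pattern: in these modeling files, MultiScaleDeformableAttention is the optionally compiled CUDA extension and stays None when loading it fails, so the pure-PyTorch path is taken instead of erroring inside forward(). A minimal sketch under that assumption (names besides MultiScaleDeformableAttention are illustrative):

MultiScaleDeformableAttention = None  # simulates a failed custom-kernel import


def dispatch(disable_custom_kernels: bool) -> str:
    # The old guard checked only the config flag, so a missing kernel could
    # still be reached; the new guard also routes to PyTorch when the import
    # failed, instead of relying on the try/except around the kernel call.
    if disable_custom_kernels or MultiScaleDeformableAttention is None:
        return "pytorch multi_scale_deformable_attention"
    return "MultiScaleDeformableAttentionFunction (CUDA kernel)"


assert dispatch(disable_custom_kernels=False).startswith("pytorch")
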

src/transformers/models/mask2former/modeling_mask2former.py

@@ -17,7 +17,7 @@
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -800,11 +800,14 @@ class Mask2FormerLoss(nn.Module):
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):

src/transformers/models/oneformer/modeling_oneformer.py

@@ -18,7 +18,7 @@ import copy
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -63,11 +63,14 @@ def _get_clones(module, N):
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
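
The mask2former and oneformer hunks above exist only because of the repository's copy-sync machinery, which is what the "fix copies" bullet in the commit message refers to: a function under a "# Copied from" marker must stay identical to its source, and make fix-copies (driven by utils/check_copies.py) rewrites it whenever the source changes. A sketch of the convention (signature abbreviated):

# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(value, value_spatial_shapes, sampling_locations, attention_weights):
    ...  # body kept identical to the deformable_detr source by `make fix-copies`


# Code tagged "# Ignore copy" is exempt from the identity check; the rt_detr
# hunks below delete those tags where its hand-written lines now match the
# source again, and grounding_dino gains one where it intentionally diverges.
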

src/transformers/models/rt_detr/modeling_rt_detr.py

@@ -733,13 +733,14 @@ class RTDetrCSPRepLayer(nn.Module):
 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    # Ignore copy
     value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -838,9 +839,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        # Ignore copy
-        total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
+        total_elements = sum(height * width for height, width in spatial_shapes_list)
         if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -876,7 +875,6 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-        # Ignore copy
         if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
             output = multi_scale_deformable_attention(
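
A hedged end-to-end check of what this commit enables (a sketch, not part of the PR: it assumes a CUDA device and the public SenseTime/deformable-detr checkpoint, and forces the pure-PyTorch attention path touched above via disable_custom_kernels):

import torch
from transformers import AutoImageProcessor, DeformableDetrForObjectDetection

processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
model = DeformableDetrForObjectDetection.from_pretrained(
    "SenseTime/deformable-detr",
    torch_dtype=torch.float16,    # exercises the dtype-following fixes
    disable_custom_kernels=True,  # forces multi_scale_deformable_attention
).to("cuda").eval()

image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # stand-in image
inputs = processor(images=image, return_tensors="pt").to("cuda", torch.float16)
with torch.no_grad():
    outputs = model(**inputs)
assert outputs.logits.dtype == torch.float16  # no hidden fp32 upcasts remain
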