Optim deformable detr (#33600)

* optimize deformable detr
* fix copies
* remove deformable_detr_basline
* fix hardcoded float16 and .float()
* [run slow] deformable-detr,grounding-dino,mask2former,oneformer,rt-detr
* [run slow] deformable_detr,grounding_dino,mask2former,oneformer,rt_detr

commit ee71c9853a (parent cac4a4876b)
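The diff below touches Deformable DETR and the models that copy its multi-scale deformable attention (Grounding DINO, Mask2Former, OneFormer, RT-DETR). Two themes run through it: a plain Python list of (height, width) ints (spatial_shapes_list) is threaded through the forward passes next to the spatial_shapes tensor, so per-level size arithmetic stops reading scalars off the device; and hardcoded float32 promotions are replaced by the input dtype so half-precision inference stays in half precision. A minimal sketch of the device-sync issue, with made-up shapes (not taken from the PR):

import torch

value = torch.randn(2, 100 * 100 + 50 * 50, 8, 32)
spatial_shapes = torch.tensor([[100, 100], [50, 50]])  # shapes stored in a tensor
spatial_shapes_list = [(100, 100), (50, 50)]           # same shapes as Python ints

# Before: each .item() reads a scalar out of the tensor -- on CUDA that is a
# blocking device-to-host copy, repeated for every level on every forward.
sizes_before = [height.item() * width.item() for height, width in spatial_shapes]

# After: pure Python-int arithmetic, no tensor reads at all.
sizes_after = [height * width for height, width in spatial_shapes_list]

assert sizes_before == sizes_after
value_list = value.split(sizes_after, dim=1)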
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -523,14 +523,14 @@ class DeformableDetrSinePositionEmbedding(nn.Module):
     def forward(self, pixel_values, pixel_mask):
         if pixel_mask is None:
             raise ValueError("No pixel mask provided")
-        y_embed = pixel_mask.cumsum(1, dtype=torch.float32)
-        x_embed = pixel_mask.cumsum(2, dtype=torch.float32)
+        y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
+        x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
         if self.normalize:
             eps = 1e-6
             y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
+        dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

         pos_x = x_embed[:, :, :, None] / dim_t
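The hunk above drops two hardcoded float32 promotions so the sine position embedding is computed in the dtype of pixel_values; previously the cumsum result and the .float() on dim_t silently upcast under fp16, which is what the commit message calls "fix hardcoded float16 and .float()". A standalone sketch of the behavioral difference (illustrative sizes, not from the PR):

import torch

pixel_values = torch.randn(1, 3, 4, 4, dtype=torch.float16)
pixel_mask = torch.ones(1, 4, 4, dtype=torch.long)

y_embed_before = pixel_mask.cumsum(1, dtype=torch.float32)      # always fp32
y_embed_after = pixel_mask.cumsum(1, dtype=pixel_values.dtype)  # follows the input

print(y_embed_before.dtype, y_embed_after.dtype)  # torch.float32 torch.float16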
@@ -580,11 +580,14 @@ def build_position_encoding(config):


 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -672,6 +675,7 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -681,7 +685,8 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):

         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+        total_elements = sum(height * width for height, width in spatial_shapes_list)
+        if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
             )
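The sanity check above now runs on the int list as well; before, the tensor reduction (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() produced a device scalar that had to be compared on the host. A runnable sketch of the reworked check with assumed shapes:

spatial_shapes_list = [(100, 100), (50, 50), (25, 25)]
sequence_length = 100 * 100 + 50 * 50 + 25 * 25

total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
    raise ValueError(
        "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
    )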
@@ -716,9 +721,11 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

-        if self.disable_custom_kernels:
+        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+            output = multi_scale_deformable_attention(
+                value, spatial_shapes_list, sampling_locations, attention_weights
+            )
         else:
             try:
                 # custom kernel
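Besides switching the PyTorch path to spatial_shapes_list, the guard now also takes the pure-PyTorch branch when the custom CUDA op never loaded (MultiScaleDeformableAttention is None), rather than hitting the try/except around the kernel call on every forward. A toy sketch of the short-circuit, with stand-in values (not the module's actual state):

MultiScaleDeformableAttention = None  # e.g. the custom kernel failed to build
disable_custom_kernels = False

use_pytorch_path = disable_custom_kernels or MultiScaleDeformableAttention is None
print("pytorch" if use_pytorch_path else "custom kernel")  # pytorch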
@@ -732,7 +739,9 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
                 )
             except Exception:
                 # PyTorch implementation
-                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+                output = multi_scale_deformable_attention(
+                    value, spatial_shapes_list, sampling_locations, attention_weights
+                )
         output = self.output_proj(output)

         return output, attention_weights
@@ -877,6 +886,7 @@ class DeformableDetrEncoderLayer(nn.Module):
         position_embeddings: torch.Tensor = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -909,6 +919,7 @@ class DeformableDetrEncoderLayer(nn.Module):
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )
@@ -974,6 +985,7 @@ class DeformableDetrDecoderLayer(nn.Module):
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -1025,6 +1037,7 @@ class DeformableDetrDecoderLayer(nn.Module):
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )
@@ -1216,6 +1229,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         attention_mask=None,
         position_embeddings=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1257,7 +1271,8 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

-        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
+        spatial_shapes_tuple = tuple(spatial_shapes_list)
+        reference_points = self.get_reference_points(spatial_shapes_tuple, valid_ratios, device=inputs_embeds.device)

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
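get_reference_points now receives the per-level shapes as a tuple of Python ints, so the meshgrid sizes inside it never have to be read back from a tensor. A simplified sketch of that generation step; reference_points_for_level is a hypothetical helper, and the valid_ratios scaling the real method applies is omitted:

import torch

def reference_points_for_level(height, width, device=None):
    # linspace/meshgrid sizes are plain ints -- nothing is read off the device
    ref_y, ref_x = torch.meshgrid(
        torch.linspace(0.5, height - 0.5, height, device=device),
        torch.linspace(0.5, width - 0.5, width, device=device),
        indexing="ij",
    )
    return torch.stack((ref_x.reshape(-1) / width, ref_y.reshape(-1) / height), -1)

spatial_shapes_tuple = ((100, 100), (50, 50))
points = torch.cat([reference_points_for_level(h, w) for h, w in spatial_shapes_tuple])
print(points.shape)  # torch.Size([12500, 2])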
@@ -1272,6 +1287,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
                    position_embeddings,
                    reference_points,
                    spatial_shapes,
+                    spatial_shapes_list,
                    level_start_index,
                    output_attentions,
                )
@@ -1282,6 +1298,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
                    position_embeddings=position_embeddings,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
+                    spatial_shapes_list=spatial_shapes_list,
                    level_start_index=level_start_index,
                    output_attentions=output_attentions,
                )
@@ -1338,6 +1355,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
         position_embeddings=None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1413,6 +1431,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                    position_embeddings,
                    reference_points_input,
                    spatial_shapes,
+                    spatial_shapes_list,
                    level_start_index,
                    encoder_hidden_states,
                    encoder_attention_mask,
@@ -1425,6 +1444,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                    encoder_hidden_states=encoder_hidden_states,
                    reference_points=reference_points_input,
                    spatial_shapes=spatial_shapes,
+                    spatial_shapes_list=spatial_shapes_list,
                    level_start_index=level_start_index,
                    encoder_attention_mask=encoder_attention_mask,
                    output_attentions=output_attentions,
@@ -1586,7 +1606,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         temperature = 10000
         scale = 2 * math.pi

-        dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float()
+        dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
         dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
         # batch_size, num_queries, 4
         proposals = proposals.sigmoid() * scale
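Same dtype fix as in the position embedding: dim_t is created directly in the dtype of proposals instead of an int64 arange followed by .float(). For feature counts of this size the indices stay exactly representable even in fp16; a quick check with assumed sizes (not from the PR):

import torch

num_pos_feats = 128
proposals = torch.randn(2, 300, 4, dtype=torch.float16)

dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
assert dim_t.dtype == torch.float16 and dim_t[-1].item() == 127.0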
@@ -1717,7 +1737,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                     source = self.input_proj[level](features[-1][0])
                 else:
                     source = self.input_proj[level](sources[-1])
-                mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
+                mask = nn.functional.interpolate(pixel_mask[None].to(pixel_values.dtype), size=source.shape[-2:]).to(
+                    torch.bool
+                )[0]
                 pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
                 sources.append(source)
                 masks.append(mask)
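interpolate() has no bool kernel, so the mask must be cast to a floating dtype first; casting to pixel_values.dtype instead of a hardcoded .float() keeps the op in the model's working precision. A standalone sketch (fp32 here so it also runs on CPU builds without fp16 interpolate kernels; under half-precision inference pixel_values would be fp16 and the cast follows suit):

import torch
import torch.nn as nn

pixel_values = torch.randn(1, 3, 32, 32)
pixel_mask = torch.ones(1, 32, 32, dtype=torch.bool)

# cast bool -> input dtype, resize, then cast back to bool
mask = nn.functional.interpolate(
    pixel_mask[None].to(pixel_values.dtype), size=(16, 16)
).to(torch.bool)[0]
print(mask.shape, mask.dtype)  # torch.Size([1, 16, 16]) torch.bool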
@@ -1732,11 +1754,11 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         source_flatten = []
         mask_flatten = []
         lvl_pos_embed_flatten = []
-        spatial_shapes = []
+        spatial_shapes_list = []
         for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
             batch_size, num_channels, height, width = source.shape
             spatial_shape = (height, width)
-            spatial_shapes.append(spatial_shape)
+            spatial_shapes_list.append(spatial_shape)
             source = source.flatten(2).transpose(1, 2)
             mask = mask.flatten(1)
             pos_embed = pos_embed.flatten(2).transpose(1, 2)
@@ -1747,7 +1769,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
         source_flatten = torch.cat(source_flatten, 1)
         mask_flatten = torch.cat(mask_flatten, 1)
         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
         level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
         valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)

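Note the dual bookkeeping: the shapes are still materialized as a tensor once per forward, because the custom CUDA kernel and the level_start_index arithmetic consume tensors; only the repeated scalar reads go away. A sketch with illustrative shapes:

import torch

spatial_shapes_list = [(100, 100), (50, 50), (25, 25)]  # Python ints, sync-free math
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long)  # for the kernel
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
print(level_start_index.tolist())  # [0, 10000, 12500]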
@@ -1759,6 +1781,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
                attention_mask=mask_flatten,
                position_embeddings=lvl_pos_embed_flatten,
                spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                level_start_index=level_start_index,
                valid_ratios=valid_ratios,
                output_attentions=output_attentions,
@@ -1816,6 +1839,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
            encoder_attention_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            output_attentions=output_attentions,
src/transformers/models/grounding_dino/modeling_grounding_dino.py

@@ -583,11 +583,14 @@ def build_position_encoding(config):

 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -676,6 +679,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
         spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -685,6 +689,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):

         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
+        # Ignore copy
         if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -720,7 +725,7 @@ class GroundingDinoMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

-        if self.disable_custom_kernels:
+        if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
             output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
         else:
src/transformers/models/mask2former/modeling_mask2former.py

@@ -17,7 +17,7 @@
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -800,11 +800,14 @@ class Mask2FormerLoss(nn.Module):

 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
src/transformers/models/oneformer/modeling_oneformer.py

@@ -18,7 +18,7 @@ import copy
 import math
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -63,11 +63,14 @@ def _get_clones(module, N):

 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
+    value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
src/transformers/models/rt_detr/modeling_rt_detr.py

@@ -733,13 +733,14 @@ class RTDetrCSPRepLayer(nn.Module):

 # Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
 def multi_scale_deformable_attention(
-    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
+    value: Tensor,
+    value_spatial_shapes: Union[Tensor, List[Tuple]],
+    sampling_locations: Tensor,
+    attention_weights: Tensor,
 ) -> Tensor:
     batch_size, _, num_heads, hidden_dim = value.shape
     _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-    # Ignore copy
     value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)

     sampling_grids = 2 * sampling_locations - 1
     sampling_value_list = []
     for level_id, (height, width) in enumerate(value_spatial_shapes):
@@ -838,9 +839,7 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):

         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        # Ignore copy
-        total_elements = sum(shape[0] * shape[1] for shape in spatial_shapes_list)
+        total_elements = sum(height * width for height, width in spatial_shapes_list)
         if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
@@ -876,7 +875,6 @@ class RTDetrMultiscaleDeformableAttention(nn.Module):
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

-        # Ignore copy
         if self.disable_custom_kernels or MultiScaleDeformableAttention is None:
             # PyTorch implementation
             output = multi_scale_deformable_attention(