mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-02 12:20:05 +06:00

* No more Tuple, List, Dict * make fixup * More style fixes * Docstring fixes with regex replacement * Trigger tests * Redo fixes after rebase * Fix copies * [test all] * update * [test all] * update * [test all] * make style after rebase * Patch the hf_argparser test * Patch the hf_argparser test * style fixes * style fixes * style fixes * Fix docstrings in Cohere test * [test all] --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1693 lines
79 KiB
Python
1693 lines
79 KiB
Python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from examples/modular-transformers/modular_test_detr.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_test_detr.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
import math
|
|
import warnings
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor, nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...integrations import use_kernel_forward_from_hub
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
from ...modeling_outputs import BaseModelOutput
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...pytorch_utils import meshgrid
|
|
from ...utils import (
|
|
ModelOutput,
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_model_forward,
|
|
is_timm_available,
|
|
replace_return_docstrings,
|
|
requires_backends,
|
|
)
|
|
from ...utils.backbone_utils import load_backbone
|
|
from .configuration_test_detr import TestDetrConfig
|
|
|
|
|
|
if is_timm_available():
|
|
from timm import create_model
|
|
|
|
_CONFIG_FOR_DOC = "TestDetrConfig"
|
|
|
|
|
|
@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
class MultiScaleDeformableAttention(nn.Module):
    """
    Pure-PyTorch multi-scale deformable attention: samples every feature level with `grid_sample` and combines
    the samples with the given attention weights.

    NOTE(review): the decorator presumably lets a hub-provided kernel replace this `forward`; the unused arguments
    (`value_spatial_shapes`, `level_start_index`, `im2col_step`) are kept so both implementations share a signature.
    """

    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: list[tuple],
        level_start_index: Tensor,
        sampling_locations: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape

        # One chunk of `value` per feature level, sized height * width along the sequence axis.
        level_sizes = [height * width for height, width in value_spatial_shapes_list]
        per_level_values = value.split(level_sizes, dim=1)
        # Map the sampling locations from [0, 1] to grid_sample's [-1, 1] coordinate range.
        grids = 2 * sampling_locations - 1

        sampled_per_level = []
        for level, (height, width) in enumerate(value_spatial_shapes_list):
            # (batch, h*w, heads, dim) -> (batch, h*w, heads*dim) -> (batch, heads*dim, h*w) -> (batch*heads, dim, h, w)
            level_value = per_level_values[level].flatten(2).transpose(1, 2)
            level_value = level_value.reshape(batch_size * num_heads, hidden_dim, height, width)
            # (batch, queries, heads, points, 2) -> (batch, heads, queries, points, 2) -> (batch*heads, queries, points, 2)
            level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
            # -> (batch*heads, dim, queries, points)
            sampled_per_level.append(
                nn.functional.grid_sample(
                    level_value,
                    level_grid,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=False,
                )
            )

        # (batch, queries, heads, levels, points) -> (batch*heads, 1, queries, levels*points)
        weights = attention_weights.transpose(1, 2).reshape(
            batch_size * num_heads, 1, num_queries, num_levels * num_points
        )
        stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
        combined = (stacked * weights).sum(-1)
        output = combined.view(batch_size, num_heads * hidden_dim, num_queries)
        return output.transpose(1, 2).contiguous()
|
|
|
|
|
|
@dataclass
class TestDetrDecoderOutput(ModelOutput):
    """
    Base class for outputs of the TestDetrDecoder. This class adds two attributes to
    BaseModelOutputWithCrossAttentions, namely:
    - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
    - a stacked tensor of intermediate reference points.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
            Stacked intermediate hidden states (output of each layer of the decoder).
        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 2 or 4)`):
            Stacked intermediate reference points (reference points of each layer of the decoder). The last
            dimension holds either 2 coordinates (x, y) or 4 (x, y, width, height), the two formats accepted by
            the deformable cross-attention.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads,
            num_feature_levels, decoder_n_points)`. Attentions weights of the decoder's deformable cross-attention
            layer, after the attention softmax, used to weight the sampled values in the cross-attention heads.
    """

    # Field order matters: ModelOutput exposes these positionally as well as by name.
    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
@dataclass
class TestDetrModelOutput(ModelOutput):
    """
    Base class for outputs of the Deformable DETR encoder-decoder model.

    Args:
        init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
            Initial reference points sent through the Transformer decoder.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.
        intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
            Stacked intermediate hidden states (output of each layer of the decoder).
        intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
            Stacked intermediate reference points (reference points of each layer of the decoder).
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, num_queries, hidden_size)`. Hidden-states of the decoder at the output of each layer
            plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, num_queries,
            num_queries)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_queries, num_heads,
            num_feature_levels, decoder_n_points)`. Attentions weights of the decoder's deformable cross-attention
            layer, after the attention softmax, used to weight the sampled values in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each
            layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads,
            num_feature_levels, encoder_n_points)`. Attentions weights of the encoder's deformable self-attention,
            after the attention softmax, used to weight the sampled values in the self-attention heads.
        enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
            Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
            picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
            foreground and background).
        enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
            Logits of predicted bounding boxes coordinates in the first stage.
    """

    # Field order matters: ModelOutput exposes these positionally as well as by name.
    init_reference_points: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    intermediate_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_reference_points: Optional[torch.FloatTensor] = None
    decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
    enc_outputs_class: Optional[torch.FloatTensor] = None
    enc_outputs_coord_logits: Optional[torch.FloatTensor] = None
|
|
|
|
|
|
class TestDetrFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d whose batch statistics and affine parameters are fixed buffers (never trained or updated).

    Copy-paste from torchvision.misc.ops with an epsilon added before `rsqrt`; without it, models other than
    torchvision.models.resnet[18,34,50,101] produce NaNs.
    """

    def __init__(self, n):
        super().__init__()
        # Registration order (weight, bias, running_mean, running_var) defines the state-dict key order.
        initial_values = (
            ("weight", torch.ones),
            ("bias", torch.zeros),
            ("running_mean", torch.zeros),
            ("running_var", torch.ones),
        )
        for buffer_name, factory in initial_values:
            self.register_buffer(buffer_name, factory(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # A regular BatchNorm2d checkpoint carries `num_batches_tracked`, which this module has no buffer for;
        # drop it so strict loading does not report it as unexpected.
        state_dict.pop(prefix + "num_batches_tracked", None)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # Reshape the per-channel buffers once so they broadcast over (batch, channels, height, width).
        weight, bias, running_var, running_mean = (
            buffer.reshape(1, -1, 1, 1) for buffer in (self.weight, self.bias, self.running_var, self.running_mean)
        )
        scale = weight * (running_var + 1e-5).rsqrt()
        shift = bias - running_mean * scale
        return x * scale + shift
|
|
|
|
|
|
def replace_batch_norm(model):
    r"""
    Recursively replace all `torch.nn.BatchNorm2d` with `TestDetrFrozenBatchNorm2d`.

    The replacement starts from the frozen module's defaults (weight=1, bias=0, running_mean=0, running_var=1).
    Each tensor the original `BatchNorm2d` actually has is then copied over, unless it lives on the `meta` device.
    A `BatchNorm2d` built with `affine=False` (no weight/bias) therefore keeps the identity affine transform.
    One built with `track_running_stats=False` (no running statistics) keeps mean 0 / variance 1.
    Previously both cases crashed.

    Args:
        model (torch.nn.Module):
            input model; modified in place.
    """
    meta_device = torch.device("meta")
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            new_module = TestDetrFrozenBatchNorm2d(module.num_features)

            for attr in ("weight", "bias", "running_mean", "running_var"):
                source = getattr(module, attr)
                # Missing tensors (affine=False / track_running_stats=False) keep the frozen defaults;
                # meta tensors carry no data to copy.
                if source is None or source.device == meta_device:
                    continue
                getattr(new_module, attr).data.copy_(source)

            model._modules[name] = new_module

        if len(list(module.children())) > 0:
            replace_batch_norm(module)
|
|
|
|
|
|
class TestDetrConvEncoder(nn.Module):
    """
    Convolutional backbone, created either directly with the timm library or through the AutoBackbone API
    (`load_backbone`).

    Every `nn.BatchNorm2d` inside the backbone is swapped for `TestDetrFrozenBatchNorm2d` via `replace_batch_norm`.
    For ResNet backbones, parameters outside the later stages are frozen (`requires_grad=False`).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # For backwards compatibility the timm library is used directly instead of the AutoBackbone API.
        if config.use_timm_backbone:
            # Defaults reproduce the previously hard-coded values; `backbone_kwargs` may override them.
            requires_backends(self, ["timm"])
            timm_kwargs = getattr(config, "backbone_kwargs", {})
            timm_kwargs = {} if timm_kwargs is None else timm_kwargs.copy()
            default_out_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,)
            out_indices = timm_kwargs.pop("out_indices", default_out_indices)
            in_chans = timm_kwargs.pop("in_chans", config.num_channels)
            if config.dilation:
                timm_kwargs.setdefault("output_stride", 16)
            backbone = create_model(
                config.backbone,
                pretrained=config.use_pretrained_backbone,
                features_only=True,
                out_indices=out_indices,
                in_chans=in_chans,
                **timm_kwargs,
            )
        else:
            backbone = load_backbone(config)

        # Freeze batch statistics by replacing BatchNorm2d with the frozen variant.
        with torch.no_grad():
            replace_batch_norm(backbone)
        self.model = backbone
        if config.use_timm_backbone:
            self.intermediate_channel_sizes = self.model.feature_info.channels()
        else:
            self.intermediate_channel_sizes = self.model.channels

        if config.backbone is not None:
            backbone_model_type = config.backbone
        elif config.backbone_config is not None:
            backbone_model_type = config.backbone_config.model_type
        else:
            raise ValueError("Either `backbone` or `backbone_config` should be provided in the config")

        if "resnet" in backbone_model_type:
            # Parameters whose name mentions none of these stage markers are frozen.
            if config.use_timm_backbone:
                trainable_markers = ("layer2", "layer3", "layer4")
            else:
                trainable_markers = ("stage.1", "stage.2", "stage.3")
            for name, parameter in self.model.named_parameters():
                if not any(marker in name for marker in trainable_markers):
                    parameter.requires_grad_(False)

    def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
        # timm's features_only models return the feature maps directly; AutoBackbone wraps them in an output object.
        if self.config.use_timm_backbone:
            feature_maps = self.model(pixel_values)
        else:
            feature_maps = self.model(pixel_values).feature_maps

        # The batch axis of the mask is treated as the channel axis of a single-image batch so `interpolate`
        # can resize (batch, height, width) masks to each feature map's spatial size.
        float_mask = pixel_mask[None].float()
        return [
            (feature_map, nn.functional.interpolate(float_mask, size=feature_map.shape[-2:]).to(torch.bool)[0])
            for feature_map in feature_maps
        ]
|
|
|
|
|
|
class TestDetrConvModel(nn.Module):
    """
    Runs the convolutional encoder and computes a 2D position embedding for every feature map it returns.
    """

    def __init__(self, conv_encoder, position_embedding):
        super().__init__()
        self.conv_encoder = conv_encoder
        self.position_embedding = position_embedding

    def forward(self, pixel_values, pixel_mask):
        # The encoder yields a list of (feature_map, mask) pairs, one per backbone level.
        out = self.conv_encoder(pixel_values, pixel_mask)
        # Position embeddings are cast to the feature map's dtype so they can be combined with it later.
        pos = [self.position_embedding(feature_map, mask).to(feature_map.dtype) for feature_map, mask in out]
        return out, pos
|
|
|
|
|
|
class TestDetrSinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.

    Args:
        embedding_dim (`int`): number of channels per axis; the output has `2 * embedding_dim` channels
            (y-embedding channels first, then x-embedding channels).
        temperature (`int`): base of the geometric progression of frequencies.
        normalize (`bool`): if true, cumulative positions are rescaled to `[0, scale]` per image.
        scale (`float`, *optional*): normalization range; only allowed with `normalize=True`. Defaults to `2 * pi`.
    """

    def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        # Use truthiness rather than `is False` so falsy non-bool values (e.g. 0) are rejected too.
        if scale is not None and not normalize:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, pixel_values, pixel_mask):
        """
        Args:
            pixel_values (`torch.Tensor`): only its dtype and device are used.
            pixel_mask (`torch.Tensor` of shape `(batch_size, height, width)`): 1 for valid pixels, 0 for padding.

        Returns:
            `torch.Tensor` of shape `(batch_size, 2 * embedding_dim, height, width)`.
        """
        if pixel_mask is None:
            raise ValueError("No pixel mask provided")
        # Positions are the running count of valid pixels along each axis.
        y_embed = pixel_mask.cumsum(1, dtype=pixel_values.dtype)
        x_embed = pixel_mask.cumsum(2, dtype=pixel_values.dtype)
        if self.normalize:
            eps = 1e-6
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        # Channel pairs (2i, 2i+1) share the frequency temperature ** (2i / embedding_dim).
        dim_t = torch.arange(self.embedding_dim, dtype=pixel_values.dtype, device=pixel_values.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
|
|
|
|
|
|
class TestDetrLearnedPositionEmbedding(nn.Module):
    """
    Learned absolute 2D position embeddings, with separate tables for rows and columns.

    Each table has 50 entries, so feature maps may be at most 50 rows by 50 columns.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.row_embeddings = nn.Embedding(50, embedding_dim)
        self.column_embeddings = nn.Embedding(50, embedding_dim)

    def forward(self, pixel_values, pixel_mask=None):
        """Return `(batch_size, 2 * embedding_dim, height, width)` embeddings: column channels first, then row channels."""
        batch_size = pixel_values.shape[0]
        height, width = pixel_values.shape[-2:]
        device = pixel_values.device

        column_table = self.column_embeddings(torch.arange(width, device=device))  # (width, dim)
        row_table = self.row_embeddings(torch.arange(height, device=device))  # (height, dim)

        # Broadcast both tables onto the (height, width) grid and concatenate along the channel axis.
        grid = torch.cat(
            [column_table.unsqueeze(0).repeat(height, 1, 1), row_table.unsqueeze(1).repeat(1, width, 1)],
            dim=-1,
        )
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
|
|
|
|
|
class TestDetrMultiscaleDeformableAttention(nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.

    For every query, each head samples `n_points` locations on each of the `num_feature_levels` levels. Both the
    sampling offsets and the softmax weights over the samples are predicted from the (position-augmented) query.
    """

    def __init__(self, config: TestDetrConfig, num_heads: int, n_points: int):
        super().__init__()

        self.attn = MultiScaleDeformableAttention()

        d_model = config.d_model
        if d_model % num_heads != 0:
            raise ValueError(
                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
            )
        head_dim = d_model // num_heads
        # x & (x - 1) == 0 exactly when x is a power of two (excluding 0).
        head_dim_is_power_of_two = head_dim != 0 and (head_dim & (head_dim - 1)) == 0
        if not head_dim_is_power_of_two:
            warnings.warn(
                "You'd better set embed_dim (d_model) in TestDetrMultiscaleDeformableAttention to make the"
                " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
                " implementation."
            )

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = config.num_feature_levels
        self.n_heads = num_heads
        self.n_points = n_points

        # Creation order is kept fixed: it determines parameter initialization order and state-dict order.
        samples_per_head = self.n_levels * n_points
        self.sampling_offsets = nn.Linear(d_model, num_heads * samples_per_head * 2)
        self.attention_weights = nn.Linear(d_model, num_heads * samples_per_head)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self.disable_custom_kernels = config.disable_custom_kernels

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        """
        Returns `(output, attention_weights)` where `output` is `(batch_size, num_queries, d_model)` and
        `attention_weights` is `(batch_size, num_queries, n_heads, n_levels, n_points)`.

        `encoder_attention_mask` and `output_attentions` are accepted for interface compatibility but not used;
        `attention_mask` (boolean) zeroes the projected values at positions where it is False.
        """
        # Offsets and weights are predicted from the hidden states plus their position embeddings.
        if position_embeddings is not None:
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        if sum(height * width for height, width in spatial_shapes_list) != sequence_length:
            raise ValueError(
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
            )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # ~mask selects the masked-out positions, whose values are zeroed.
            value = value.masked_fill(~attention_mask[..., None], float(0))
        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)

        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )
        # Softmax is taken jointly over all levels and points of a head, then split back apart.
        flat_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(flat_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
        )

        # sampling_locations: (batch_size, num_queries, n_heads, n_levels, n_points, 2)
        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            # Divide by each level's (width, height) so offsets become fractions of that level's feature map.
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif num_coordinates == 4:
            # Boxes: offsets are scaled by 1 / n_points and by half the box size around the box center.
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
            )
        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        attended = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )
        return self.output_proj(attended), attention_weights
|
|
|
|
|
|
class TestDetrMultiheadAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper).
    Values are always projected from the hidden states *without* position embeddings.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, target_len, embed_dim)`): input sequence.
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*): expanded to a 4D additive
                mask with `_prepare_4d_attention_mask`.
            position_embeddings (`torch.Tensor`, *optional*): added to the queries and keys only.
            output_attentions (`bool`): whether to also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` is `(batch_size, target_len, embed_dim)`; `attn_weights` is
            `(batch_size, num_heads, target_len, source_len)` if `output_attentions` else `None`.
        """
        batch_size, target_len, embed_dim = hidden_states.size()
        # Values are projected from the raw hidden states whether or not position embeddings are given,
        # so keep a reference to them before any position embeddings are added.
        hidden_states_original = hidden_states
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        # get queries, keys and values
        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
        value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        source_len = key_states.size(1)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
            if attention_mask.size() != (batch_size, 1, target_len, source_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is"
                    f" {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask
            attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
            attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (
            batch_size * self.num_heads,
            target_len,
            self.head_dim,
        ):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, target_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
|
|
|
|
|
|
class TestDetrEncoderLayer(nn.Module):
    """
    One encoder layer: deformable self-attention over the flattened multi-scale features, then a feed-forward
    network, each followed by residual addition and layer normalization (post-norm).
    """

    def __init__(self, config: TestDetrConfig):
        super().__init__()
        self.embed_dim = config.d_model
        # Submodule creation order is kept fixed: it determines parameter initialization and state-dict order.
        self.self_attn = TestDetrMultiscaleDeformableAttention(
            config,
            num_heads=config.encoder_attention_heads,
            n_points=config.encoder_n_points,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened multi-scale features; used both as queries and as the attended sequence.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Boolean mask; values at positions where it is False are zeroed inside the attention.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Added to `hidden_states` before predicting sampling offsets and weights.
            reference_points (`torch.FloatTensor`, *optional*):
                Reference points, with a last dimension of 2 or 4.
            spatial_shapes (`torch.LongTensor`, *optional*):
                Spatial shapes of the backbone feature maps.
            spatial_shapes_list (`list`, *optional*):
                The same spatial shapes as a list of `(height, width)` pairs.
            level_start_index (`torch.LongTensor`, *optional*):
                Level start index.
            output_attentions (`bool`, *optional*):
                Whether to append the attention weights to the returned tuple.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)` when `output_attentions` is true.
        """
        # --- deformable self-attention block ---
        residual = hidden_states
        attn_output, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )
        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(residual + attn_output)

        # --- feed-forward block ---
        residual = hidden_states
        ffn_output = self.activation_fn(self.fc1(hidden_states))
        ffn_output = nn.functional.dropout(ffn_output, p=self.activation_dropout, training=self.training)
        ffn_output = nn.functional.dropout(self.fc2(ffn_output), p=self.dropout, training=self.training)
        hidden_states = self.final_layer_norm(residual + ffn_output)

        # During training, clamp activations to the dtype's finite range if any overflow to inf (or NaN) appears.
        if self.training and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
class TestDetrDecoderLayer(nn.Module):
    """
    One decoder block made of three post-norm residual sub-blocks applied in sequence:

    1. self-attention among the queries,
    2. multi-scale deformable cross-attention into the encoder output,
    3. a two-layer feed-forward network.

    Each sub-block output goes through dropout, is added to its input and is then layer-normalized.

    Args:
        config ([`TestDetrConfig`]): provides `d_model`, `decoder_attention_heads`, `attention_dropout`, `dropout`,
            `activation_function`, `activation_dropout`, `decoder_n_points` and `decoder_ffn_dim`.
    """

    def __init__(self, config: TestDetrConfig):
        super().__init__()
        width = config.d_model
        heads = config.decoder_attention_heads
        ffn_width = config.decoder_ffn_dim
        self.embed_dim = width

        # Sub-module assignment order is kept fixed: it determines registration order (and therefore
        # parameter iteration / default-initialization order).
        # query self-attention
        self.self_attn = TestDetrMultiheadAttention(
            embed_dim=width,
            num_heads=heads,
            dropout=config.attention_dropout,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(width)

        # deformable cross-attention into the encoder memory
        self.encoder_attn = TestDetrMultiscaleDeformableAttention(
            config,
            num_heads=heads,
            n_points=config.decoder_n_points,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(width)

        # position-wise feed-forward network
        self.fc1 = nn.Linear(width, ffn_width)
        self.fc2 = nn.Linear(ffn_width, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ):
        """
        Run the layer on a batch of query embeddings.

        Args:
            hidden_states (`torch.FloatTensor`):
                Query embeddings entering the layer, `(batch_size, num_queries, embed_dim)` as produced by the decoder.
            position_embeddings (`torch.FloatTensor`, *optional*):
                Positional embeddings handed to both attention modules.
            reference_points (`torch.FloatTensor`, *optional*):
                Reference points for the deformable cross-attention.
            spatial_shapes (`torch.LongTensor`, *optional*):
                Spatial shapes of the multi-scale feature maps.
            spatial_shapes_list (*optional*):
                The same shapes as a Python list, forwarded to the cross-attention.
            level_start_index (`torch.LongTensor`, *optional*):
                Offset of each feature level inside the flattened encoder sequence.
            encoder_hidden_states (`torch.FloatTensor`, *optional*):
                Encoder output that the queries cross-attend to.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Mask over encoder positions. It is forwarded unchanged to the cross-attention as both `attention_mask`
                and `encoder_attention_mask`; its exact format is defined by that module.
            output_attentions (`bool`, *optional*):
                When true, the self- and cross-attention weights are appended to the returned tuple.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights, cross_attn_weights)` when `output_attentions`.
        """
        # 1) self-attention sub-block
        skip = hidden_states
        attended, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
        )
        attended = nn.functional.dropout(attended, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(skip + attended)

        # 2) cross-attention sub-block
        skip = hidden_states
        attended, cross_attn_weights = self.encoder_attn(
            hidden_states=hidden_states,
            attention_mask=encoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            position_embeddings=position_embeddings,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            output_attentions=output_attentions,
        )
        attended = nn.functional.dropout(attended, p=self.dropout, training=self.training)
        hidden_states = self.encoder_attn_layer_norm(skip + attended)

        # 3) feed-forward sub-block
        skip = hidden_states
        expanded = self.activation_fn(self.fc1(hidden_states))
        expanded = nn.functional.dropout(expanded, p=self.activation_dropout, training=self.training)
        projected = self.fc2(expanded)
        projected = nn.functional.dropout(projected, p=self.dropout, training=self.training)
        hidden_states = self.final_layer_norm(skip + projected)

        if output_attentions:
            return (hidden_states, self_attn_weights, cross_attn_weights)
        return (hidden_states,)
|
|
|
|
|
|
class TestDetrPreTrainedModel(PreTrainedModel):
    """
    Shared base class for the TestDetr encoder, decoder and full model.

    It binds [`TestDetrConfig`] as the configuration class and defines `_init_weights`, the per-module weight
    initializer. That initializer is presumably invoked through `PreTrainedModel.post_init()`, which the subclasses
    call at the end of their `__init__`.
    """

    config_class = TestDetrConfig
    base_model_prefix = "model"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    # NOTE(review): per the PreTrainedModel convention, these modules should be kept whole when the model is
    # dispatched across devices; confirm against the base class.
    _no_split_modules = [
        r"TestDetrConvEncoder",
        r"TestDetrEncoderLayer",
        r"TestDetrDecoderLayer",
    ]

    def _init_weights(self, module):
        """
        Initialize the parameters of a single sub-module in place.

        Random initializers are called in a fixed order, so statement order matters for reproducibility.

        Args:
            module (`nn.Module`): module to initialize. Dispatch is on its type, plus two extra checks for
                `reference_points` and `level_embed` attributes that apply regardless of type.
        """
        std = self.config.init_std

        if isinstance(module, TestDetrLearnedPositionEmbedding):
            # learned row / column position tables: uniform init
            nn.init.uniform_(module.row_embeddings.weight)
            nn.init.uniform_(module.column_embeddings.weight)
        elif isinstance(module, TestDetrMultiscaleDeformableAttention):
            # With zero weights, the sampling offsets initially equal the bias. The bias assigns each head its own
            # direction at angle 2*pi*h/n_heads, rescaled so that max(|cos|, |sin|) == 1. That pattern is repeated
            # over levels and points, with the i-th point scaled by (i + 1).
            nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
            default_dtype = torch.get_default_dtype()
            thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
                2.0 * math.pi / module.n_heads
            )
            grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
            grid_init = (
                (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
                .view(module.n_heads, 1, 1, 2)
                .repeat(1, module.n_levels, module.n_points, 1)
            )
            for i in range(module.n_points):
                grid_init[:, :, i, :] *= i + 1
            with torch.no_grad():
                # note: this replaces the bias Parameter object rather than copying into it
                module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
            # attention-weight projection starts at all zeros (weight and bias)
            nn.init.constant_(module.attention_weights.weight.data, 0.0)
            nn.init.constant_(module.attention_weights.bias.data, 0.0)
            # value / output projections: Xavier-uniform weights, zero bias
            nn.init.xavier_uniform_(module.value_proj.weight.data)
            nn.init.constant_(module.value_proj.bias.data, 0.0)
            nn.init.xavier_uniform_(module.output_proj.weight.data)
            nn.init.constant_(module.output_proj.bias.data, 0.0)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # keep the padding row at zero
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # These two checks are independent of the isinstance chain above, so they also apply to container modules
        # that merely own such attributes (e.g. a model defining `reference_points` / `level_embed`).
        if hasattr(module, "reference_points") and not self.config.two_stage:
            nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
            nn.init.constant_(module.reference_points.bias.data, 0.0)
        if hasattr(module, "level_embed"):
            nn.init.normal_(module.level_embed)
|
|
|
|
|
|
class TestDetrEncoder(TestDetrPreTrainedModel):
    """
    Deformable-attention Transformer encoder: a stack of *config.encoder_layers* [`TestDetrEncoderLayer`] blocks.

    It refines the flattened multi-scale feature maps. Every layer receives the current hidden states together with
    per-position reference points computed once by [`TestDetrEncoder.get_reference_points`].

    Args:
        config: TestDetrConfig
    """

    def __init__(self, config: TestDetrConfig):
        super().__init__(config)
        self.gradient_checkpointing = False

        self.dropout = config.dropout
        self.layers = nn.ModuleList([TestDetrEncoderLayer(config) for _ in range(config.encoder_layers)])

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """
        Build normalized reference points (pixel centres) for every position of every feature level.

        These are consumed by the encoder's own `forward`.

        Args:
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                `(height, width)` of each feature map.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                Fraction of each feature map that is not padding, as `(width_ratio, height_ratio)`.
            device (`torch.device`):
                Device on which to create the tensors.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`, where `num_queries` is
            the total number of positions across all levels.
        """
        per_level = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            ys, xs = meshgrid(
                torch.linspace(0.5, h - 0.5, h, dtype=valid_ratios.dtype, device=device),
                torch.linspace(0.5, w - 0.5, w, dtype=valid_ratios.dtype, device=device),
                indexing="ij",
            )
            # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
            xs = xs.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * w)
            ys = ys.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * h)
            per_level.append(torch.stack((xs, ys), -1))
        flat_points = torch.cat(per_level, 1)
        # broadcast each position against every level's valid ratio
        return flat_points[:, :, None] * valid_ratios[:, None]

    def forward(
        self,
        inputs_embeds=None,
        attention_mask=None,
        position_embeddings=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
                - 1 for pixel features that are real (i.e. **not masked**),
                - 0 for pixel features that are padding (i.e. **masked**).
                [What are attention masks?](../glossary#attention-mask)
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of each feature map.
            spatial_shapes_list (`list` of `(height, width)` pairs):
                The same shapes as Python values; used to build the reference points.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                Starting index of each feature map.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
                Ratio of valid area in each feature level.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        hidden_states = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)
        reference_points = self.get_reference_points(
            tuple(spatial_shapes_list), valid_ratios, device=inputs_embeds.device
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # the checkpoint wrapper takes positional arguments only
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_embeddings,
                    reference_points,
                    spatial_shapes,
                    spatial_shapes_list,
                    level_start_index,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    position_embeddings=position_embeddings,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    spatial_shapes_list=spatial_shapes_list,
                    level_start_index=level_start_index,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            encoder_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=encoder_states,
                attentions=all_attentions,
            )
        return tuple(item for item in (hidden_states, encoder_states, all_attentions) if item is not None)
|
|
|
|
|
|
def inverse_sigmoid(x, eps=1e-5):
    """
    Numerically safe logit: `log(x / (1 - x))`.

    `x` is first clipped to `[0, 1]`. Both numerator and denominator are then floored at `eps`, so inputs at (or
    beyond) 0 and 1 map to finite values of about `log(eps)` and `-log(eps)`.

    Args:
        x (`torch.Tensor`): probabilities.
        eps (`float`): lower bound applied to `x` and `1 - x` before taking the ratio.

    Returns:
        `torch.Tensor` of the same shape as `x`.
    """
    probs = torch.clamp(x, min=0, max=1)
    return torch.log(torch.clamp(probs, min=eps) / torch.clamp(1 - probs, min=eps))
|
|
|
|
|
|
class TestDetrDecoder(TestDetrPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TestDetrDecoderLayer`].

    The decoder updates the query embeddings through multiple self-attention and cross-attention layers.

    Some tweaks for Deformable DETR:

    - `position_embeddings`, `reference_points`, `spatial_shapes` and `valid_ratios` are added to the forward pass.
    - it also returns a stack of intermediate outputs and reference points from all decoding layers.

    Args:
        config: TestDetrConfig
    """

    def __init__(self, config: TestDetrConfig):
        super().__init__(config)

        self.dropout = config.dropout
        self.layers = nn.ModuleList([TestDetrDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.gradient_checkpointing = False

        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR:
        # both stay `None` here; presumably an owning detection model assigns per-layer heads to them.
        self.bbox_embed = None
        self.class_embed = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings=None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        valid_ratios=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
                The query embeddings that are passed into the decoder. Required.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
                in `[0, 1]`:
                - 1 for pixels that are real (i.e. **not masked**),
                - 0 for pixels that are padding (i.e. **masked**).
            position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each self-attention layer.
            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`):
                Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
                Required; the last dimension must be 2 or 4.
            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
                Spatial shapes of the feature maps.
            spatial_shapes_list (`list` of `(height, width)` pairs, *optional*):
                The same shapes as Python values, forwarded to every decoder layer.
            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
                Indexes for the start of each feature level. In range `[0, sequence_length]`.
            valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*):
                Ratio of valid area in each feature level.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.

        Returns:
            `TestDetrDecoderOutput` when `return_dict`. Otherwise a tuple of the non-`None` entries of
            `(last_hidden_state, intermediate_hidden_states, intermediate_reference_points, hidden_states,
            attentions, cross_attentions)`. The two intermediate tensors are stacked along dim 1, one entry per layer.

        Raises:
            ValueError: if `inputs_embeds` or `reference_points` is `None`, or if the last dimension of
                `reference_points` is neither 2 nor 4.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Previously `hidden_states` was only bound when `inputs_embeds` was given, so a missing input crashed later
        # with an UnboundLocalError. Fail early with an explicit message instead.
        if inputs_embeds is None:
            raise ValueError("`inputs_embeds` must be provided to the decoder.")
        if reference_points is None:
            raise ValueError("`reference_points` must be provided to the decoder.")
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        intermediate = ()
        intermediate_reference_points = ()

        for idx, decoder_layer in enumerate(self.layers):
            num_coordinates = reference_points.shape[-1]
            if num_coordinates == 4:
                # 4-coordinate references: every coordinate is scaled by the per-level valid ratios
                reference_points_input = (
                    reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                )
            elif num_coordinates == 2:
                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
            else:
                raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_coordinates}")

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # the checkpoint wrapper takes positional arguments only
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    position_embeddings,
                    reference_points_input,
                    spatial_shapes,
                    spatial_shapes_list,
                    level_start_index,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    position_embeddings=position_embeddings,
                    encoder_hidden_states=encoder_hidden_states,
                    reference_points=reference_points_input,
                    spatial_shapes=spatial_shapes,
                    spatial_shapes_list=spatial_shapes_list,
                    level_start_index=level_start_index,
                    encoder_attention_mask=encoder_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[idx](hidden_states)
                # `num_coordinates` was validated above and `reference_points` has not changed since
                if num_coordinates == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    # note: writes into `tmp` in place (new_reference_points aliases it)
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                # refined points feed the next layer but do not propagate gradients back through it
                reference_points = new_reference_points.detach()

            intermediate += (hidden_states,)
            intermediate_reference_points += (reference_points,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Keep batch_size as first dimension
        intermediate = torch.stack(intermediate, dim=1)
        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    intermediate,
                    intermediate_reference_points,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return TestDetrDecoderOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=intermediate,
            intermediate_reference_points=intermediate_reference_points,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
def build_position_encoding(config):
    """
    Instantiate the 2-D position embedding module selected by `config.position_embedding_type`.

    Each axis receives `config.d_model // 2` channels.

    Args:
        config: object exposing `d_model` and `position_embedding_type` (`"sine"` or `"learned"`).

    Returns:
        A `TestDetrSinePositionEmbedding` (normalized) or a `TestDetrLearnedPositionEmbedding`.

    Raises:
        ValueError: for any other `position_embedding_type`.
    """
    half_dim = config.d_model // 2
    kind = config.position_embedding_type
    if kind == "sine":
        # TODO find a better way of exposing other arguments
        return TestDetrSinePositionEmbedding(half_dim, normalize=True)
    if kind == "learned":
        return TestDetrLearnedPositionEmbedding(half_dim)
    raise ValueError(f"Not supported {kind}")
|
|
|
|
|
|
# Shared class-level docstring injected into the public TestDetr models via `add_start_docstrings`.
TEST_DETR_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`TestDetrConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
# Shared `forward` argument documentation injected via `add_start_docstrings_to_model_forward`.
TEST_DETR_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it.

            Pixel values can be obtained using [`AutoImageProcessor`]. See [`TestDetrImageProcessor.__call__`]
            for details.

        pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for pixels that are real (i.e. **not masked**),
            - 0 for pixels that are padding (i.e. **masked**).

            [What are attention masks?](../glossary#attention-mask)

        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
            Not used by default. Can be used to mask object queries.
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consisting of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
            can choose to directly pass a flattened representation of an image.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
            embedded representation.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
|
|
@add_start_docstrings(
|
|
"""
|
|
The bare Deformable DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw
|
|
hidden-states without any specific head on top.
|
|
""",
|
|
TEST_DETR_START_DOCSTRING,
|
|
)
|
|
class TestDetrModel(TestDetrPreTrainedModel):
|
|
def __init__(self, config: TestDetrConfig):
    """
    Build the backbone, the per-level input projections, the encoder/decoder and the query/reference heads.

    Sub-modules are created in a fixed order, which determines their registration (and initialization) order.

    Args:
        config ([`TestDetrConfig`]): model configuration.
    """
    super().__init__(config)

    # Create backbone + positional encoding
    backbone = TestDetrConvEncoder(config)
    position_embeddings = build_position_encoding(config)
    self.backbone = TestDetrConvModel(backbone, position_embeddings)

    # Create input projection layers
    if config.num_feature_levels > 1:
        num_backbone_outs = len(backbone.intermediate_channel_sizes)
        input_proj_list = []
        # one 1x1 projection to d_model per backbone feature map
        for in_channels in backbone.intermediate_channel_sizes:
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, config.d_model, kernel_size=1),
                    nn.GroupNorm(32, config.d_model),
                )
            )
        # Additional levels beyond the backbone outputs: 3x3 stride-2 convolutions. The first one starts from the
        # last backbone map's channel count (`in_channels` left over from the loop above); later ones start from
        # d_model.
        for _ in range(config.num_feature_levels - num_backbone_outs):
            input_proj_list.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels,
                        config.d_model,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    ),
                    nn.GroupNorm(32, config.d_model),
                )
            )
            in_channels = config.d_model
        self.input_proj = nn.ModuleList(input_proj_list)
    else:
        # single level: only the last backbone map is projected
        self.input_proj = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        backbone.intermediate_channel_sizes[-1],
                        config.d_model,
                        kernel_size=1,
                    ),
                    nn.GroupNorm(32, config.d_model),
                )
            ]
        )

    if not config.two_stage:
        self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model * 2)

    self.encoder = TestDetrEncoder(config)
    self.decoder = TestDetrDecoder(config)

    # allocated uninitialized here; filled in by the `level_embed` branch of `_init_weights`
    self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model))

    if config.two_stage:
        self.enc_output = nn.Linear(config.d_model, config.d_model)
        self.enc_output_norm = nn.LayerNorm(config.d_model)
        self.pos_trans = nn.Linear(config.d_model * 2, config.d_model * 2)
        self.pos_trans_norm = nn.LayerNorm(config.d_model * 2)
    else:
        self.reference_points = nn.Linear(config.d_model, 2)

    self.post_init()
|
|
|
|
def get_encoder(self):
    """Return the encoder sub-module stored on `self.encoder`."""
    encoder_module = self.encoder
    return encoder_module
|
|
|
|
def get_decoder(self):
    """Return the decoder sub-module stored on `self.decoder`."""
    decoder_module = self.decoder
    return decoder_module
|
|
|
|
def freeze_backbone(self):
    """
    Stop gradient updates for the convolutional backbone.

    Sets `requires_grad=False` on every parameter of `self.backbone.conv_encoder.model`.
    """
    # only the tensors are needed, so iterate `parameters()` rather than unpacking unused names
    for param in self.backbone.conv_encoder.model.parameters():
        param.requires_grad_(False)
|
|
|
|
def unfreeze_backbone(self):
    """
    Re-enable gradient updates for the convolutional backbone.

    Sets `requires_grad=True` on every parameter of `self.backbone.conv_encoder.model`.
    """
    # only the tensors are needed, so iterate `parameters()` rather than unpacking unused names
    for param in self.backbone.conv_encoder.model.parameters():
        param.requires_grad_(True)
|
|
|
|
def get_valid_ratio(self, mask, dtype=torch.float32):
    """
    Compute, per image, which fraction of the (padded) feature map holds real pixels.

    Args:
        mask (`torch.Tensor` of shape `(batch_size, height, width)`): pixel mask. Its entries are summed, so it is
            assumed to be 1 (or True) for valid pixels and 0 for padding.
        dtype (`torch.dtype`): dtype of the returned ratios.

    Returns:
        `torch.Tensor` of shape `(batch_size, 2)` holding `(width_ratio, height_ratio)` for each image.
    """
    _, full_height, full_width = mask.shape
    # valid rows are counted down the first column, valid columns along the first row
    rows = mask[:, :, 0].sum(1)
    cols = mask[:, 0, :].sum(1)
    width_ratio = cols.to(dtype) / full_width
    height_ratio = rows.to(dtype) / full_height
    return torch.stack([width_ratio, height_ratio], -1)
|
|
|
|
def get_proposal_pos_embed(self, proposals):
    """
    Sine/cosine-embed every coordinate of the proposals.

    Each coordinate is passed through a sigmoid and scaled to `[0, 2*pi]`. It is then divided by
    `d_model // 2` geometrically spaced frequencies (consecutive channels share a frequency), and encoded as sin on
    even channels and cos on odd channels. The per-coordinate encodings are concatenated along the last axis.

    Args:
        proposals (`torch.Tensor`): rank-3 tensor `(batch_size, num_queries, num_coords)` of proposal logits.

    Returns:
        `torch.Tensor` of shape `(batch_size, num_queries, num_coords * (d_model // 2))`.
    """
    feats = self.config.d_model // 2
    # geometric frequency ladder with temperature 10000
    freq = torch.arange(feats, dtype=proposals.dtype, device=proposals.device)
    freq = 10000 ** (2 * torch.div(freq, 2, rounding_mode="floor") / feats)

    angles = proposals.sigmoid() * (2 * math.pi)
    # (batch, queries, coords) -> (batch, queries, coords, feats)
    angles = angles[:, :, :, None] / freq
    sin_part = angles[:, :, :, 0::2].sin()
    cos_part = angles[:, :, :, 1::2].cos()
    # interleave sin/cos pairs, then merge coords and channels into one axis
    return torch.stack((sin_part, cos_part), dim=4).flatten(2)
|
|
|
|
def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
    """Generate the encoder output proposals from encoded enc_output.

    One proposal box `(x, y, w, h)` is built per flattened encoder position:

    - `(x, y)` is the pixel centre (`index + 0.5`) on its feature level, divided by the number of valid
      (non-padded) columns/rows of that level for each image.
    - `w = h = 0.05 * 2**level`.

    Proposals with any coordinate outside `(0.01, 0.99)`, and all padded positions, are marked invalid. Their
    logits are set to `+inf` and their object-query features to 0.

    Args:
        enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
        padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`. Used as a boolean mask
            where `True` marks padding: it is negated with `~` to count valid rows/columns and passed directly to
            `masked_fill`.
        spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.

    Returns:
        `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
            - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
              directly predict a bounding box. (without the need of a decoder)
            - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse
              sigmoid.
    """
    batch_size = enc_output.shape[0]
    proposals = []
    _cur = 0  # offset of the current level inside the flattened sequence
    for level, (height, width) in enumerate(spatial_shapes):
        # this level's slice of the padding mask, reshaped back to its 2-D map
        mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
        # number of non-padded rows (scanning column 0) and columns (scanning row 0)
        valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        grid_y, grid_x = meshgrid(
            torch.linspace(
                0,
                height - 1,
                height,
                dtype=enc_output.dtype,
                device=enc_output.device,
            ),
            torch.linspace(
                0,
                width - 1,
                width,
                dtype=enc_output.dtype,
                device=enc_output.device,
            ),
            indexing="ij",
        )
        # integer (x, y) pixel positions, shape (height, width, 2)
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        # pixel centres normalized by each image's valid extent
        scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
        # box size doubles with every feature level
        width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
        proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
        proposals.append(proposal)
        _cur += height * width
    output_proposals = torch.cat(proposals, 1)
    # a proposal is valid only if all four coordinates lie strictly inside (0.01, 0.99)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals / (1 - output_proposals))  # inverse sigmoid
    # padded and out-of-range positions get +inf logits
    output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

    # assign each pixel as an object query
    object_query = enc_output
    # zero features at padded / invalid positions before the projection + LayerNorm
    object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
    object_query = object_query.masked_fill(~output_proposals_valid, float(0))
    object_query = self.enc_output_norm(self.enc_output(object_query))
    return object_query, output_proposals
|
|
|
|
@add_start_docstrings_to_model_forward(TEST_DETR_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TestDetrModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    pixel_values: torch.FloatTensor,
    pixel_mask: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.FloatTensor] = None,
    encoder_outputs: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[tuple[torch.FloatTensor], TestDetrModelOutput]:
    r"""
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, TestDetrModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("SenseTime/deformable-detr")
    >>> model = TestDetrModel.from_pretrained("SenseTime/deformable-detr")

    >>> inputs = image_processor(images=image, return_tensors="pt")

    >>> outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    >>> list(last_hidden_states.shape)
    [1, 300, 256]
    ```"""
    # Pipeline: backbone -> per-level input projection (+ extra levels) -> flatten all levels -> encoder
    # -> build decoder queries (learned embeddings, or top-k encoder proposals in two-stage mode) -> decoder.
    #
    # NOTE: `decoder_attention_mask`, `inputs_embeds` and `decoder_inputs_embeds` are accepted for API
    # compatibility but are not read anywhere in this method.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    batch_size, num_channels, height, width = pixel_values.shape
    device = pixel_values.device

    if pixel_mask is None:
        # No mask supplied: every pixel is treated as valid (all ones).
        pixel_mask = torch.ones((batch_size, height, width), dtype=torch.long, device=device)

    # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper).
    # First, send pixel_values + pixel_mask through the backbone to obtain the features (a list of
    # (feature_map, mask) tuples) together with one position embedding per feature map.
    features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)

    # Second, apply 1x1 convolution (`self.input_proj[level]`) to reduce the channel dimension to d_model
    # (256 by default).
    sources = []
    masks = []
    for level, (source, mask) in enumerate(features):
        # Validate before projecting so no work is wasted on an unusable level.
        if mask is None:
            raise ValueError("No attention mask was provided")
        sources.append(self.input_proj[level](source))
        masks.append(mask)

    # Third, lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage.
    # The first extra level is built from the last raw backbone feature map, later ones from the
    # previously created (already projected) level.
    if self.config.num_feature_levels > len(sources):
        num_backbone_levels = len(sources)
        # Level-independent, so computed once instead of on every iteration.
        float_pixel_mask = pixel_mask[None].to(pixel_values.dtype)
        for level in range(num_backbone_levels, self.config.num_feature_levels):
            if level == num_backbone_levels:
                source = self.input_proj[level](features[-1][0])
            else:
                source = self.input_proj[level](sources[-1])
            # Resize the pixel mask to this level's spatial size and turn it back into a boolean mask.
            mask = nn.functional.interpolate(float_pixel_mask, size=source.shape[-2:]).to(torch.bool)[0]
            pos_l = self.backbone.position_embedding(source, mask).to(source.dtype)
            sources.append(source)
            masks.append(mask)
            position_embeddings_list.append(pos_l)

    # Create queries: in single-stage mode they come from `query_position_embeddings`; in two-stage mode
    # they are derived from the encoder output further below.
    query_embeds = None
    if not self.config.two_stage:
        query_embeds = self.query_position_embeddings.weight

    # Prepare encoder inputs by flattening each level's spatial dimensions and concatenating all levels.
    source_flatten = []
    mask_flatten = []
    lvl_pos_embed_flatten = []
    spatial_shapes_list = []
    for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
        batch_size, num_channels, height, width = source.shape
        spatial_shapes_list.append((height, width))
        # (batch, channels, height, width) -> (batch, height * width, channels)
        source = source.flatten(2).transpose(1, 2)
        mask = mask.flatten(1)
        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        # Add the per-level embedding so tokens from different levels are distinguishable.
        lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        source_flatten.append(source)
        mask_flatten.append(mask)
    source_flatten = torch.cat(source_flatten, 1)
    mask_flatten = torch.cat(mask_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
    # Offset of each level's first token in the flattened sequence (exclusive cumsum of height * width).
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)

    # Fourth, send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output)
    # through the encoder. Also provide spatial_shapes, level_start_index and valid_ratios.
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            inputs_embeds=source_flatten,
            attention_mask=mask_flatten,
            position_embeddings=lvl_pos_embed_flatten,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )
    # Conversely, a precomputed ModelOutput must become a plain tuple when return_dict=False; otherwise the
    # tuple concatenation that builds the final output below raises a TypeError.
    elif not return_dict and isinstance(encoder_outputs, ModelOutput):
        encoder_outputs = encoder_outputs.to_tuple()

    # Fifth, prepare decoder inputs
    batch_size, _, num_channels = encoder_outputs[0].shape
    enc_outputs_class = None
    enc_outputs_coord_logits = None
    if self.config.two_stage:
        # gen_encoder_output_proposals expects a padding mask (True = padded position), hence the inversion.
        object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
            encoder_outputs[0], ~mask_flatten, spatial_shapes_list
        )

        # hack implementation for two-stage Deformable DETR
        # apply a detection head to each pixel (A.4 in paper)
        # linear projection for bounding box binary classification (i.e. foreground and background)
        enc_outputs_class = self.decoder.class_embed[-1](object_query_embedding)
        # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch)
        delta_bbox = self.decoder.bbox_embed[-1](object_query_embedding)
        enc_outputs_coord_logits = delta_bbox + output_proposals

        # only keep top scoring `config.two_stage_num_proposals` proposals
        topk = self.config.two_stage_num_proposals
        topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
        topk_coords_logits = torch.gather(
            enc_outputs_coord_logits,
            1,
            topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
        )

        # Selected box logits are detached: gradients do not flow back through the proposal selection.
        topk_coords_logits = topk_coords_logits.detach()
        reference_points = topk_coords_logits.sigmoid()
        init_reference_points = reference_points
        pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
        query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
    else:
        # query_embeds holds both the query position embedding and the target, split along the last dim.
        query_embed, target = torch.split(query_embeds, num_channels, dim=1)
        query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
        target = target.unsqueeze(0).expand(batch_size, -1, -1)
        reference_points = self.reference_points(query_embed).sigmoid()
        init_reference_points = reference_points

    decoder_outputs = self.decoder(
        inputs_embeds=target,
        position_embeddings=query_embed,
        encoder_hidden_states=encoder_outputs[0],
        encoder_attention_mask=mask_flatten,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        spatial_shapes_list=spatial_shapes_list,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if not return_dict:
        # Two-stage outputs are only present (non-None) in two-stage mode.
        enc_outputs = tuple(value for value in [enc_outputs_class, enc_outputs_coord_logits] if value is not None)
        tuple_outputs = (init_reference_points,) + decoder_outputs + encoder_outputs + enc_outputs

        return tuple_outputs

    return TestDetrModelOutput(
        init_reference_points=init_reference_points,
        last_hidden_state=decoder_outputs.last_hidden_state,
        intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        intermediate_reference_points=decoder_outputs.intermediate_reference_points,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        enc_outputs_class=enc_outputs_class,
        enc_outputs_coord_logits=enc_outputs_coord_logits,
    )
|