Add DAB-DETR for object detection (#30803)

* initial commit

* encoder+decoder layer changes WIP

* architecture checks

* working version of detection + segmentation

* fix modeling outputs

* fix return dict + output att/hs

* found the position embedding masking bug

* pre-training version

* added image processors

* typo in init.py

* iterupdate set to false

* fixed num_labels in class_output linear layer bias init

* multihead attention shape fixes

* test improvements

* test update

* dab-detr model_doc update

* dab-detr model_doc update2

* test fix:test_retain_grad_hidden_states_attentions

* config file cleanup and variable renaming

* config file cleanup and variable renaming fix

* updated convert_to_hf file

* small fixes

* style and quality checks

* return_dict fix

* Merge branch main into add_dab_detr

* small comment fix

* skip test_inputs_embeds test

* image processor updates + image processor test updates

* check copies test fix update

* updates for check_copies.py test

* updates for check_copies.py test2

* tied weights fix

* fixed image processing tests and fixed shared weights issues

* added numpy ndarray option to get_expected_values method in test_image_processing_dab_detr.py

* delete prints from test file

* SafeTensor modification to solve HF Trainer issue

* removing the safetensor modifications

* make fix-copies and HF upload have been added.

* fixed index.md

* fixed repo consistency

* style fix and DabDetrImageProcessor docstring update

* requested modifications after the first review

* Update src/transformers/models/dab_detr/image_processing_dab_detr.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* repo consistency has been fixed

* update copied NestedTensor function after main merge

* Update src/transformers/models/dab_detr/modeling_dab_detr.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* temp commit

* temp commit2

* temp commit 3

* unit tests are fixed

* fixed repo consistency

* updated expected_boxes variable values based on related notebook results in DABDETRIntegrationTests file.

* temporary config modifications and repo consistency fixes

* Put dilation parameter back to config

* pattern embeddings have been added to the rename_keys method

* add dilation comment to config + add as an exception in check_config_attributes SPECIAL CASES

* delete FeatureExtractor part from docs.md

* requested modifications in modeling_dab_detr.py

* [run_slow] dab_detr

* deleted last segmentation code part, updated conversion script and changed the hf path in test files

* temp commit of requested modifications

* temp commit of requested modifications 2

* updated config file, resolved codepaths and refactored conversion script

* updated decoder layer block types and refactored conversion script

* style and quality update

* small modifications based on the request

* attentions are refactored

* removed loss functions from the modeling file, added the loss function to loss_utils, tried to move the MLP layer generation to the config but it failed

* deleted the image processor

* fixed conversion script + quality and style

* fixed config_att

* [run_slow] dab_detr

* changing model path in conversion file and in test file

* fix Decoder variable naming

* testing the old loss function

* switched back to the new loss function and testing with the old attention functions

* switched back to the last known-good version of the modeling file

* moved back to the version from when the review was requested

* added missing newline at the end of the file

* old version test

* turned back to the newest model version but changed the image processor

* style fix

* style fix after merge main

* [run_slow] dab_detr

* [run_slow] dab_detr

* added device and type for head bias data part

* [run_slow] dab_detr

* fixed model head bias data fill

* changed test_inference_object_detection_head assertTrue checks to torch.testing.assert_close

* fixes part 1

* quality update

* self.bbox_embed in decoder has been restored

* changed assertTrue torch allclose checks to torch.testing.assert_close

* modelcard markdown file has been updated

* deleted intermediate list from decoder module

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
David 2025-02-04 18:28:27 +01:00 committed by GitHub
parent fe52679e74
commit 8d73a38606
20 changed files with 3259 additions and 2 deletions


@ -643,6 +643,8 @@
title: ConvNeXTV2
- local: model_doc/cvt
title: CvT
- local: model_doc/dab-detr
title: DAB-DETR
- local: model_doc/deformable_detr
title: Deformable DETR
- local: model_doc/deit


@ -110,6 +110,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
| [CvT](model_doc/cvt) | ✅ | ✅ | ❌ |
| [DAB-DETR](model_doc/dab-detr) | ✅ | ❌ | ❌ |
| [DAC](model_doc/dac) | ✅ | ❌ | ❌ |
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |


@ -0,0 +1,119 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# DAB-DETR
## Overview
The DAB-DETR model was proposed in [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329) by Shilong Liu, Feng Li, Hao Zhang, Xiao Yang, Xianbiao Qi, Hang Su, Jun Zhu, Lei Zhang.
DAB-DETR is an enhanced variant of Conditional DETR. It utilizes dynamically updated anchor boxes to provide both a reference query point (x, y) and a reference anchor size (w, h), improving cross-attention computation. This new approach achieves 45.7% AP when trained for 50 epochs with a single ResNet-50 model as the backbone.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dab_detr_convergence_plot.png"
alt="drawing" width="600"/>
The abstract from the paper is the following:
*We present in this paper a novel query formulation using dynamic anchor boxes
for DETR (DEtection TRansformer) and offer a deeper understanding of the role
of queries in DETR. This new formulation directly uses box coordinates as queries
in Transformer decoders and dynamically updates them layer-by-layer. Using box
coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR,
but also allows us to modulate the positional attention map using the box width
and height information. Such a design makes it clear that queries in DETR can be
implemented as performing soft ROI pooling layer-by-layer in a cascade manner.
As a result, it leads to the best performance on MS-COCO benchmark among
the DETR-like detection models under the same setting, e.g., AP 45.7% using
ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive
experiments to confirm our analysis and verify the effectiveness of our methods.*
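In rough terms, the anchor width and height are used to rescale the positional part of the decoder cross-attention. A sketch of the modulated positional attention term, paraphrased from the paper (notation approximate: `PE` is the sinusoidal positional encoding, `D` the hidden dimension, `(x, y, w, h)` the query's anchor box, and `w_{q,ref}`, `h_{q,ref}` reference sizes predicted from the content query):

$$
\text{ModulateAttn}\big((x, y), (x_{ref}, y_{ref})\big) = \Big( \text{PE}(x) \cdot \text{PE}(x_{ref}) \, \tfrac{w_{q,ref}}{w} + \text{PE}(y) \cdot \text{PE}(y_{ref}) \, \tfrac{h_{q,ref}}{h} \Big) \Big/ \sqrt{D}
$$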
This model was contributed by [davidhajdu](https://huggingface.co/davidhajdu).
The original code can be found [here](https://github.com/IDEA-Research/DAB-DETR).
## How to Get Started with the Model
Use the code below to get started with the model.
```python
import torch
import requests
from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)
for result in results:
for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
score, label = score.item(), label_id.item()
box = [round(i, 2) for i in box.tolist()]
print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```
This should output
```
cat: 0.87 [14.7, 49.39, 320.52, 469.28]
remote: 0.86 [41.08, 72.37, 173.39, 117.2]
cat: 0.86 [344.45, 19.43, 639.85, 367.86]
remote: 0.61 [334.27, 75.93, 367.92, 188.81]
couch: 0.59 [-0.04, 1.34, 639.9, 477.09]
```
There are three other ways to instantiate a DAB-DETR model (depending on what you prefer):
Option 1: Instantiate DAB-DETR with pre-trained weights for entire model
```py
>>> from transformers import DabDetrForObjectDetection
>>> model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
```
Option 2: Instantiate DAB-DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection
>>> config = DabDetrConfig()
>>> model = DabDetrForObjectDetection(config)
```
Option 3: Instantiate DAB-DETR with randomly initialized weights for backbone + Transformer
```py
>>> config = DabDetrConfig(use_pretrained_backbone=False)
>>> model = DabDetrForObjectDetection(config)
```
## DabDetrConfig
[[autodoc]] DabDetrConfig
## DabDetrModel
[[autodoc]] DabDetrModel
- forward
## DabDetrForObjectDetection
[[autodoc]] DabDetrForObjectDetection
- forward


@ -328,6 +328,7 @@ _import_structure = {
"CTRLTokenizer",
],
"models.cvt": ["CvtConfig"],
"models.dab_detr": ["DabDetrConfig"],
"models.dac": ["DacConfig", "DacFeatureExtractor"],
"models.data2vec": [
"Data2VecAudioConfig",
@ -1898,6 +1899,13 @@ else:
"CvtPreTrainedModel",
]
)
_import_structure["models.dab_detr"].extend(
[
"DabDetrForObjectDetection",
"DabDetrModel",
"DabDetrPreTrainedModel",
]
)
_import_structure["models.dac"].extend(
[
"DacModel",
@ -5387,6 +5395,9 @@ if TYPE_CHECKING:
CTRLTokenizer,
)
from .models.cvt import CvtConfig
from .models.dab_detr import (
DabDetrConfig,
)
from .models.dac import (
DacConfig,
DacFeatureExtractor,
@ -6926,6 +6937,11 @@ if TYPE_CHECKING:
CvtModel,
CvtPreTrainedModel,
)
from .models.dab_detr import (
DabDetrForObjectDetection,
DabDetrModel,
DabDetrPreTrainedModel,
)
from .models.dac import (
DacModel,
DacPreTrainedModel,


@ -217,6 +217,7 @@ ACT2CLS = {
"silu": nn.SiLU,
"swish": nn.SiLU,
"tanh": nn.Tanh,
"prelu": nn.PReLU,
}
ACT2FN = ClassInstantier(ACT2CLS)
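For context, `ACT2FN` wraps `ACT2CLS` in a `ClassInstantier`, so looking an activation up by name returns a ready-to-use module instance. A minimal sketch of how the new `"prelu"` entry is consumed (assuming the public `transformers.activations` import path):

```python
from torch import nn

from transformers.activations import ACT2FN

# Indexing ACT2FN instantiates the class registered in ACT2CLS, so "prelu"
# yields a learnable nn.PReLU module (DAB-DETR's default activation function).
activation = ACT2FN["prelu"]
assert isinstance(activation, nn.PReLU)
```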


@ -128,6 +128,7 @@ LOSS_MAPPING = {
"ForObjectDetection": ForObjectDetectionLoss,
"DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
"DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,


@ -63,6 +63,7 @@ from . import (
cpmant,
ctrl,
cvt,
dab_detr,
dac,
data2vec,
dbrx,


@ -79,6 +79,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("cpmant", "CpmAntConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("dab-detr", "DabDetrConfig"),
("dac", "DacConfig"),
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
@ -399,6 +400,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("cpmant", "CPM-Ant"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("dab-detr", "DAB-DETR"),
("dac", "DAC"),
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),


@ -78,6 +78,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("cpmant", "CpmAntModel"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("dab-detr", "DabDetrModel"),
("dac", "DacModel"),
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
@ -592,6 +593,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("conditional_detr", "ConditionalDetrModel"),
("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("dab-detr", "DabDetrModel"),
("data2vec-vision", "Data2VecVisionModel"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
@ -890,6 +892,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Object Detection mapping
("conditional_detr", "ConditionalDetrForObjectDetection"),
("dab-detr", "DabDetrForObjectDetection"),
("deformable_detr", "DeformableDetrForObjectDetection"),
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),


@ -52,7 +52,7 @@ class ConditionalDetrConfig(PretrainedConfig):
Number of object queries, i.e. detection slots. This is the maximal number of objects
[`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
d_model (`int`, *optional*, defaults to 256):
Dimension of the layers.
Dimension of the hidden states. This general dimension is used by several components, such as the encoder layers and the projection layers in the decoder.
encoder_layers (`int`, *optional*, defaults to 6):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 6):


@ -74,6 +74,8 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
Reference points of each decoder layer (the 2D anchor point coordinates).
"""
intermediate_hidden_states: Optional[torch.FloatTensor] = None
@ -116,6 +118,8 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
Reference points of each decoder layer (the 2D anchor point coordinates).
"""
intermediate_hidden_states: Optional[torch.FloatTensor] = None


@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_dab_detr import *
from .modeling_dab_detr import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@ -0,0 +1,260 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DAB-DETR model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class DabDetrConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DabDetrModel`]. It is used to instantiate
a DAB-DETR model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the DAB-DETR
[IDEA-Research/dab_detr-base](https://huggingface.co/IDEA-Research/dab_detr-base) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
num_queries (`int`, *optional*, defaults to 300):
Number of object queries, i.e. detection slots. This is the maximal number of objects
[`DabDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
encoder_layers (`int`, *optional*, defaults to 6):
Number of encoder layers.
encoder_ffn_dim (`int`, *optional*, defaults to 2048):
Dimension of the "intermediate" (often named feed-forward) layer in encoder.
encoder_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_layers (`int`, *optional*, defaults to 6):
Number of decoder layers.
decoder_ffn_dim (`int`, *optional*, defaults to 2048):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
decoder_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Indicates whether the transformer model architecture is an encoder-decoder or not.
activation_function (`str` or `function`, *optional*, defaults to `"prelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
hidden_size (`int`, *optional*, defaults to 256):
Dimension of the hidden states. This general dimension is used by several components, such as the encoder layers and the projection layers in the decoder.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
init_xavier_std (`float`, *optional*, defaults to 1.0):
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
auxiliary_loss (`bool`, *optional*, defaults to `False`):
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when `use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 2):
Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5):
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
giou_cost (`float`, *optional*, defaults to 2):
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
cls_loss_coefficient (`float`, *optional*, defaults to 2):
Relative weight of the classification loss in the object detection loss function.
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
Relative weight of the L1 bounding box loss in the object detection loss.
giou_loss_coefficient (`float`, *optional*, defaults to 2):
Relative weight of the generalized IoU loss in the object detection loss.
focal_alpha (`float`, *optional*, defaults to 0.25):
Alpha parameter in the focal loss.
temperature_height (`int`, *optional*, defaults to 20):
Temperature parameter to tune the flatness of positional attention (HEIGHT)
temperature_width (`int`, *optional*, defaults to 20):
Temperature parameter to tune the flatness of positional attention (WIDTH)
query_dim (`int`, *optional*, defaults to 4):
Query dimension parameter represents the size of the output vector.
random_refpoints_xy (`bool`, *optional*, defaults to `False`):
Whether to randomly initialize the x and y coordinates of the anchor boxes and keep them fixed (not learned).
keep_query_pos (`bool`, *optional*, defaults to `False`):
Whether to concatenate the projected positional embedding from the object query into the original query (key) in every decoder layer.
num_patterns (`int`, *optional*, defaults to 0):
Number of pattern embeddings.
normalize_before (`bool`, *optional*, defaults to `False`):
Whether we use a normalization layer in the Encoder or not.
sine_position_embedding_scale (`float`, *optional*, defaults to `None`):
Scaling factor applied to the normalized positional encodings.
initializer_bias_prior_prob (`float`, *optional*):
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
If `None`, `prior_prob` is computed as `prior_prob = 1 / (num_labels + 1)` when initializing model weights.
Examples:
```python
>>> from transformers import DabDetrConfig, DabDetrModel
>>> # Initializing a DAB-DETR IDEA-Research/dab_detr-base style configuration
>>> configuration = DabDetrConfig()
>>> # Initializing a model (with random weights) from the IDEA-Research/dab_detr-base style configuration
>>> model = DabDetrModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "dab-detr"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_attention_heads": "encoder_attention_heads",
}
def __init__(
self,
use_timm_backbone=True,
backbone_config=None,
backbone="resnet50",
use_pretrained_backbone=True,
backbone_kwargs=None,
num_queries=300,
encoder_layers=6,
encoder_ffn_dim=2048,
encoder_attention_heads=8,
decoder_layers=6,
decoder_ffn_dim=2048,
decoder_attention_heads=8,
is_encoder_decoder=True,
activation_function="prelu",
hidden_size=256,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
init_xavier_std=1.0,
auxiliary_loss=False,
dilation=False,
class_cost=2,
bbox_cost=5,
giou_cost=2,
cls_loss_coefficient=2,
bbox_loss_coefficient=5,
giou_loss_coefficient=2,
focal_alpha=0.25,
temperature_height=20,
temperature_width=20,
query_dim=4,
random_refpoints_xy=False,
keep_query_pos=False,
num_patterns=0,
normalize_before=False,
sine_position_embedding_scale=None,
initializer_bias_prior_prob=None,
**kwargs,
):
if query_dim != 4:
raise ValueError("The query dimensions has to be 4.")
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = 3 # num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
backbone = None
# set timm attributes to None
dilation = None
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_queries = num_queries
self.hidden_size = hidden_size
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.init_xavier_std = init_xavier_std
self.num_hidden_layers = encoder_layers
self.auxiliary_loss = auxiliary_loss
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs
# Hungarian matcher
self.class_cost = class_cost
self.bbox_cost = bbox_cost
self.giou_cost = giou_cost
# Loss coefficients
self.cls_loss_coefficient = cls_loss_coefficient
self.bbox_loss_coefficient = bbox_loss_coefficient
self.giou_loss_coefficient = giou_loss_coefficient
self.focal_alpha = focal_alpha
self.query_dim = query_dim
self.random_refpoints_xy = random_refpoints_xy
self.keep_query_pos = keep_query_pos
self.num_patterns = num_patterns
self.normalize_before = normalize_before
self.temperature_width = temperature_width
self.temperature_height = temperature_height
self.sine_position_embedding_scale = sine_position_embedding_scale
self.initializer_bias_prior_prob = initializer_bias_prior_prob
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
__all__ = ["DabDetrConfig"]
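As a quick illustration of the timm code path above: with the default timm backbone, enabling `dilation` (the DC5 variant) is folded into `backbone_kwargs` when the config is created. A minimal sketch (the printed dict is what the constructor above builds, not a stored run):

```python
from transformers import DabDetrConfig

config = DabDetrConfig(dilation=True)
# output_stride=16 replaces the last stride with dilation; out_indices selects all four stages
print(config.backbone_kwargs)
# {'output_stride': 16, 'out_indices': [1, 2, 3, 4], 'in_chans': 3}
```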


@ -0,0 +1,233 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DAB-DETR checkpoints."""
import argparse
import gc
import json
import re
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
# for DAB-DETR, also convert the reference point head and the query scale MLP
r"input_proj\.(bias|weight)": r"input_projection.\1",
r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight",
r"class_embed\.(bias|weight)": r"class_embed.\1",
# negative lookbehind because of the overlap
r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2",
r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2",
r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1",
r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2",
r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2",
r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2",
r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1",
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function
# output projection
r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2",
# FFN layers
r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3",
# normalization layers
# nm1
r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2",
# nm2
r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2",
# activation function weight
r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight",
#########################################################################################################################################
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activiation function weight
r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2",
# FFNs
r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3",
# nm1
r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2",
# nm2
r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2",
# nm3
r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2",
# activation function weight
r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight",
# q, k, v projections and biases in self-attention in decoder
r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2",
# q, k, v projections in cross-attention in decoder
r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2",
r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2",
}
# Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
"""
This function should be applied only once, on the concatenated keys to efficiently rename using
the key mappings.
"""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
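# Illustrative examples only (not executed by the script): applying the mapping above
# renames original checkpoint keys as follows, e.g.
#   "transformer.encoder.layers.0.norm1.weight"          -> "encoder.layers.0.self_attn_layer_norm.weight"
#   "transformer.decoder.layers.3.sa_qcontent_proj.bias" -> "decoder.layers.3.self_attn.self_attn_query_content_proj.bias"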
def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub):
logger.info("Converting image processor...")
format = "coco_detection"
image_processor = ConditionalDetrImageProcessor(format=format)
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor")
@torch.no_grad()
def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
# Load a modified config for DC5 variants: dilation has to be set when the config is created, because the backbone kwargs are derived from it at construction time.
if "dc5" in model_name:
config = DabDetrConfig(dilation=True)
else:
# load default config
config = DabDetrConfig()
# set other attributes
if "dab-detr-resnet-50-dc5" == model_name:
config.temperature_height = 10
config.temperature_width = 10
if "fixxy" in model_name:
config.random_refpoints_xy = True
if "pat3" in model_name:
config.num_patterns = 3
# only applies when the number of patterns (the num_patterns parameter in the config) is greater than 0, e.g. r50-pat3 or r50dc5-pat3
ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"})
config.num_labels = 91
repo_id = "huggingface/label-files"
filename = "coco-detection-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
# load original model from local path
loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"))["model"]
# Rename the keys of the original state dictionary to be HF-compatible
all_keys = list(loaded.keys())
new_keys = convert_old_keys_to_new_keys(all_keys)
state_dict = {}
for key in all_keys:
if "backbone.0.body" in key:
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone")
state_dict[new_key] = loaded[key]
# Q, K, V encoder values mapping
elif re.search("self_attn.in_proj_(weight|bias)", key):
# Dynamically find the layer number
pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)"
match = re.search(pattern, key)
if match:
layer_num = match.group(1)
else:
raise ValueError(f"Pattern not found in key: {key}")
in_proj_value = loaded.pop(key)
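# torch's packed in_proj stacks the q, k and v projections along dim 0 (3 * hidden_size, i.e. 3 * 256 for the base model), so it is split into thirds below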
if "weight" in key:
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :]
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :]
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :]
elif "bias" in key:
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256]
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512]
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:]
else:
new_key = new_keys[key]
state_dict[new_key] = loaded[key]
del loaded
gc.collect()
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
prefix = "model."
for key in state_dict.copy().keys():
if not key.startswith("class_embed") and not key.startswith("bbox_predictor"):
val = state_dict.pop(key)
state_dict[prefix + key] = val
# finally, create HuggingFace model and load state dict
model = DabDetrForObjectDetection(config)
model.load_state_dict(state_dict)
model.eval()
logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub(repo_id=model_name, commit_message="Add new model")
def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
logger.info("Converting image processor...")
write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub)
logger.info(f"Converting model {model_name}...")
write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="dab-detr-resnet-50",
type=str,
help="Name of the DAB_DETR model you'd like to convert.",
)
parser.add_argument(
"--pretrained_model_weights_path",
default="modelzoo/R50/checkpoint.pth",
type=str,
help="The path of the original model weights like: modelzoo/checkpoint.pth",
)
parser.add_argument(
"--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model."
)
parser.add_argument(
"--push_to_hub",
default=True,
type=bool,
help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.",
)
args = parser.parse_args()
convert_dab_detr_checkpoint(
args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub
)

File diff suppressed because it is too large.


@ -52,7 +52,7 @@ class DetrConfig(PretrainedConfig):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
detect in a single image. For COCO, we recommend 100 queries.
d_model (`int`, *optional*, defaults to 256):
Dimension of the layers.
Dimension of the hidden states. This general dimension is used by several components, such as the encoder layers and the projection layers in the decoder.
encoder_layers (`int`, *optional*, defaults to 6):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 6):


@ -2482,6 +2482,27 @@ class CvtPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class DabDetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DabDetrModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DabDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DacModel(metaclass=DummyObject):
_backends = ["torch"]


@ -0,0 +1,839 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch DAB-DETR model."""
import inspect
import math
import unittest
from typing import Dict, List, Tuple
from transformers import DabDetrConfig, ResNetConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
import torch.nn.functional as F
from transformers import (
DabDetrForObjectDetection,
DabDetrModel,
)
if is_vision_available():
from PIL import Image
from transformers import ConditionalDetrImageProcessor
class DabDetrModelTester:
def __init__(
self,
parent,
batch_size=8,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
num_queries=12,
num_channels=3,
min_size=200,
max_size=200,
n_targets=8,
num_labels=91,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_queries = num_queries
self.num_channels = num_channels
self.min_size = min_size
self.max_size = max_size
self.n_targets = n_targets
self.num_labels = num_labels
# we also set the expected seq length for both encoder and decoder
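# the CNN backbone downsamples the image by a factor of 32, so the flattened feature map fed to the encoder has (min_size / 32) * (max_size / 32) tokens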
self.encoder_seq_length = math.ceil(self.min_size / 32) * math.ceil(self.max_size / 32)
self.decoder_seq_length = self.num_queries
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
labels = None
if self.use_labels:
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
labels = []
for i in range(self.batch_size):
target = {}
target["class_labels"] = torch.randint(
high=self.num_labels, size=(self.n_targets,), device=torch_device
)
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
target["masks"] = torch.rand(self.n_targets, self.min_size, self.max_size, device=torch_device)
labels.append(target)
config = self.get_config()
return config, pixel_values, pixel_mask, labels
def get_config(self):
resnet_config = ResNetConfig(
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
hidden_act="relu",
num_labels=3,
out_features=["stage2", "stage3", "stage4"],
out_indices=[2, 3, 4],
)
return DabDetrConfig(
hidden_size=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
num_queries=self.num_queries,
num_labels=self.num_labels,
use_timm_backbone=False,
backbone_config=resnet_config,
backbone=None,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
return config, inputs_dict
def create_and_check_dab_detr_model(self, config, pixel_values, pixel_mask, labels):
model = DabDetrModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size)
)
def create_and_check_dab_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = DabDetrForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_torch
class DabDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
DabDetrModel,
DabDetrForObjectDetection,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{
"image-feature-extraction": DabDetrModel,
"object-detection": DabDetrForObjectDetection,
}
if is_torch_available()
else {}
)
is_encoder_decoder = True
test_torchscript = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
zero_init_hidden_state = True
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class.__name__ in ["DabDetrForObjectDetection"]:
labels = []
for i in range(self.model_tester.batch_size):
target = {}
target["class_labels"] = torch.ones(
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
)
target["boxes"] = torch.ones(
self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
)
target["masks"] = torch.ones(
self.model_tester.n_targets,
self.model_tester.min_size,
self.model_tester.max_size,
device=torch_device,
dtype=torch.float,
)
labels.append(target)
inputs_dict["labels"] = labels
return inputs_dict
def setUp(self):
self.model_tester = DabDetrModelTester(self)
self.config_tester = ConfigTester(self, config_class=DabDetrConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_dab_detr_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dab_detr_model(*config_and_inputs)
def test_dab_detr_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dab_detr_object_detection_head_model(*config_and_inputs)
# TODO: check if this works again for PyTorch 2.x.y
@unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="DETR does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="DETR does not use inputs_embeds")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="DETR does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="DETR does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="DETR is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="DETR does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@slow
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
torch.testing.assert_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=1e-5,
rtol=1e-5,
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
if self.has_attentions:
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
seq_length = seq_length * self.model_tester.chunk_length
else:
seq_length = self.model_tester.seq_length
self.assertListEqual(
[hidden_states[0].shape[1], hidden_states[0].shape[2]],
[seq_length, self.model_tester.hidden_size],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
self.assertListEqual(
[hidden_states[0].shape[1], hidden_states[0].shape[2]],
[decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# The tolerance is relaxed to 1e-2 (instead of 1e-3) because the check occasionally failed at the stricter threshold
def test_batching_equivalence(self):
"""
Tests that the model supports batching and that the output is nearly the same for the same input in
different batch sizes.
(Why "nearly the same" not "exactly the same"? Batching uses different matmul shapes, which often leads to
different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
"""
def get_tensor_equivalence_function(batched_input):
# models operating on continuous spaces have higher abs difference than LMs
# instead, we can rely on cos distance for image/speech models, similar to `diffusers`
if "input_ids" not in batched_input:
return lambda tensor1, tensor2: (
1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)
)
return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2))
def recursive_check(batched_object, single_row_object, model_name, key):
if isinstance(batched_object, (list, tuple)):
for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
recursive_check(batched_object_value, single_row_object_value, model_name, key)
elif isinstance(batched_object, dict):
for batched_object_value, single_row_object_value in zip(
batched_object.values(), single_row_object.values()
):
recursive_check(batched_object_value, single_row_object_value, model_name, key)
# do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
elif batched_object is None or not isinstance(batched_object, torch.Tensor):
return
elif batched_object.dim() == 0:
return
else:
# indexing the first element does not always work
# e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
slice_ids = [slice(0, index) for index in single_row_object.shape]
batched_row = batched_object[slice_ids]
self.assertFalse(
torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
)
self.assertFalse(
torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
)
self.assertFalse(
torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
)
self.assertFalse(
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
)
self.assertTrue(
(equivalence(batched_row, single_row_object)) <= 1e-02,
msg=(
f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
f"Difference={equivalence(batched_row, single_row_object)}."
),
)
config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
equivalence = get_tensor_equivalence_function(batched_input)
for model_class in self.all_model_classes:
config.output_hidden_states = True
model_name = model_class.__name__
if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
model = model_class(config).to(torch_device).eval()
batch_size = self.model_tester.batch_size
single_row_input = {}
for key, value in batched_input_prepared.items():
if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
# e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size
single_batch_shape = value.shape[0] // batch_size
single_row_input[key] = value[:single_batch_shape]
else:
single_row_input[key] = value
with torch.no_grad():
model_batched_output = model(**batched_input_prepared)
model_row_output = model(**single_row_input)
if isinstance(model_batched_output, torch.Tensor):
model_batched_output = {"model_output": model_batched_output}
model_row_output = {"model_output": model_row_output}
for key in model_batched_output:
# DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan`
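# (the first entry of decoder_hidden_states is the zero-initialized queries themselves, so the comparison is degenerate and is skipped)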
if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key:
model_batched_output[key] = model_batched_output[key][1:]
model_row_output[key] = model_row_output[key][1:]
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
decoder_seq_length = self.model_tester.decoder_seq_length
encoder_seq_length = self.model_tester.encoder_seq_length
decoder_key_length = self.model_tester.decoder_seq_length
encoder_key_length = self.model_tester.encoder_seq_length
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
del inputs_dict["output_hidden_states"]
config.output_attentions = True
config.output_hidden_states = False
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = 6
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
# decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states
added_hidden_states = 3
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_retain_grad_hidden_states_attentions(self):
# removed retain_grad and grad on decoder_hidden_states, as queries don't require grad
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
# logits
output = outputs[0]
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_hidden_states.retain_grad()
encoder_attentions = outputs.encoder_attentions[0]
encoder_attentions.retain_grad()
decoder_attentions = outputs.decoder_attentions[0]
decoder_attentions.retain_grad()
cross_attentions = outputs.cross_attentions[0]
cross_attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
def test_forward_auxiliary_loss(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.auxiliary_loss = True
# only test for object detection and segmentation model
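# auxiliary_outputs holds the intermediate decoder layers' predictions (the last layer's prediction is the main output), hence num_hidden_layers - 1 entries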
for model_class in self.all_model_classes[1:]:
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
outputs = model(**inputs)
self.assertIsNotNone(outputs.auxiliary_outputs)
self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.num_hidden_layers - 1)
def test_training(self):
if not self.model_tester.is_training:
self.skipTest(reason="ModelTester is not configured to run training tests")
# We only have loss with ObjectDetection
model_class = self.all_model_classes[-1]
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
expected_arg_names = ["pixel_values", "pixel_mask"]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" in arg_names
else []
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = ["pixel_values", "pixel_mask"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_different_timm_backbone(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if model_class.__name__ == "DabDetrForObjectDetection":
expected_shape = (
self.model_tester.batch_size,
self.model_tester.num_queries,
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
self.assertTrue(outputs)
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
configs_no_init.init_xavier_std = 1e9
# Copied from RT-DETR
configs_no_init.initializer_bias_prior_prob = 0.2
bias_value = -1.3863  # -log_e((1 - 0.2) / 0.2)
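# (chosen so that sigmoid(bias) equals the prior probability: sigmoid(-log(4)) = 1 / (1 + 4) = 0.2)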
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "bbox_attention" in name and "bias" not in name:
self.assertLess(
100000,
abs(param.data.max().item()),
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# Modified from RT-DETR
elif "class_embed" in name and "bias" in name:
bias_tensor = torch.full_like(param.data, bias_value)
torch.testing.assert_close(
param.data,
bias_tensor,
atol=1e-4,
rtol=1e-4,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif "activation_fn" in name and config.activation_function == "prelu":
self.assertTrue(
param.data.mean() == 0.25,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
elif "backbone.conv_encoder.model" in name:
continue
elif "self_attn.in_proj_weight" in name:
self.assertIn(
((param.data.mean() * 1e2).round() / 1e2).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
TOLERANCE = 1e-4
CHECKPOINT = "IDEA-Research/dab-detr-resnet-50"
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_timm
@require_vision
@slow
class DabDetrModelIntegrationTests(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ConditionalDetrImageProcessor.from_pretrained(CHECKPOINT) if is_vision_available() else None
def test_inference_no_head(self):
model = DabDetrModel.from_pretrained(CHECKPOINT).to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(pixel_values=encoding.pixel_values)
expected_shape = torch.Size((1, 300, 256))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[-0.4879, -0.2594, 0.4524], [-0.4997, -0.4258, 0.4329], [-0.8220, -0.4996, 0.0577]]
).to(torch_device)
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=2e-4, rtol=2e-4)
def test_inference_object_detection_head(self):
model = DabDetrForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
pixel_values = encoding["pixel_values"].to(torch_device)
with torch.no_grad():
outputs = model(pixel_values)
# verify logits + box predictions
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_slice_logits = torch.tensor(
[[-10.1765, -5.5243, -8.9324], [-9.8138, -5.6721, -7.5161], [-10.3054, -5.6081, -8.5931]]
).to(torch_device)
torch.testing.assert_close(outputs.logits[0, :3, :3], expected_slice_logits, atol=3e-4, rtol=3e-4)
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
expected_slice_boxes = torch.tensor(
[[0.3708, 0.3000, 0.2753], [0.5211, 0.6125, 0.9495], [0.2897, 0.6730, 0.5459]]
).to(torch_device)
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4, rtol=1e-4)
# verify postprocessing
results = image_processor.post_process_object_detection(
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
)[0]
expected_scores = torch.tensor([0.8732, 0.8563, 0.8554, 0.6079, 0.5896]).to(torch_device)
expected_labels = [17, 75, 17, 75, 63]
expected_boxes = torch.tensor([14.6970, 49.3892, 320.5165, 469.2765]).to(torch_device)
self.assertEqual(len(results["scores"]), 5)
torch.testing.assert_close(results["scores"], expected_scores, atol=1e-4, rtol=1e-4)
self.assertSequenceEqual(results["labels"].tolist(), expected_labels)
torch.testing.assert_close(results["boxes"][0, :], expected_boxes, atol=1e-4, rtol=1e-4)
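
For context, below is a minimal, self-contained sketch of the same inference path the integration test above exercises. The checkpoint, image processor class, post-processing call, and test image path are taken directly from the test; the printed report at the end is only illustrative.

import torch
from PIL import Image

from transformers import ConditionalDetrImageProcessor, DabDetrForObjectDetection

# same COCO sample image the tests use
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")

image_processor = ConditionalDetrImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
model.eval()

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# rescale normalized boxes back to the original image size and keep detections above the confidence threshold
results = image_processor.post_process_object_detection(
    outputs, threshold=0.3, target_sizes=[image.size[::-1]]
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score.item():.3f} {box.tolist()}")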

@@ -161,6 +161,16 @@ SPECIAL_CASES_TO_ALLOW = {
"giou_loss_coefficient",
"mask_loss_coefficient",
],
"DabDetrConfig": [
"dilation",
"bbox_cost",
"bbox_loss_coefficient",
"class_cost",
"cls_loss_coefficient",
"focal_alpha",
"giou_cost",
"giou_loss_coefficient",
],
"DetrConfig": [
"bbox_cost",
"bbox_loss_coefficient",