mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-15 10:38:23 +06:00
Add DAB-DETR for object detection (#30803)
* initial commit * encoder+decoder layer changes WIP * architecture checks * working version of detection + segmentation * fix modeling outputs * fix return dict + output att/hs * found the position embedding masking bug * pre-training version * added iamge processors * typo in init.py * iterupdate set to false * fixed num_labels in class_output linear layer bias init * multihead attention shape fixes * test improvements * test update * dab-detr model_doc update * dab-detr model_doc update2 * test fix:test_retain_grad_hidden_states_attentions * config file clean and renaming variables * config file clean and renaming variables fix * updated convert_to_hf file * small fixes * style and qulity checks * return_dict fix * Merge branch main into add_dab_detr * small comment fix * skip test_inputs_embeds test * image processor updates + image processor test updates * check copies test fix update * updates for check_copies.py test * updates for check_copies.py test2 * tied weights fix * fixed image processing tests and fixed shared weights issues * added numpy nd array option to get_Expected_values method in test_image_processing_dab_detr.py * delete prints from test file * SafeTensor modification to solve HF Trainer issue * removing the safetensor modifications * make fix copies and hf uplaod has been added. * fixed index.md * fixed repo consistency * styel fix and dabdetrimageprocessor docstring update * requested modifications after the first review * Update src/transformers/models/dab_detr/image_processing_dab_detr.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * repo consistency has been fixed * update copied NestedTensor function after main merge * Update src/transformers/models/dab_detr/modeling_dab_detr.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * temp commit * temp commit2 * temp commit 3 * unit tests are fixed * fixed repo consistency * updated expected_boxes varible values based on related notebook results in DABDETRIntegrationTests file. * temporarialy config modifications and repo consistency fixes * Put dilation parameter back to config * pattern embeddings have been added to the rename_keys method * add dilation comment to config + add as an exception in check_config_attributes SPECIAL CASES * delete FeatureExtractor part from docs.md * requested modifications in modeling_dab_detr.py * [run_slow] dab_detr * deleted last segmentation code part, updated conversion script and changed the hf path in test files * temp commit of requested modifications * temp commit of requested modifications 2 * updated config file, resolved codepaths and refactored conversion script * updated decodelayer block types and refactored conversion script * style and quality update * small modifications based on the request * attentions are refactored * removed loss functions from modeling file, added loss function to lossutils, tried to move the MLP layer generation to config but it failed * deleted imageprocessor * fixed conversion script + quality and style * fixed config_att * [run_slow] dab_detr * changing model path in conversion file and in test file * fix Decoder variable naming * testing the old loss function * switched back to the new loss function and testing with the odl attention functions * switched back to the new last good result modeling file * moved back to the version when I asked the review * missing new line at the end of the file * old version test * turn back to newest mdoel versino but change image processor * style fix * style fix after merge main * [run_slow] dab_detr * [run_slow] dab_detr * added device and type for head bias data part * [run_slow] dab_detr * fixed model head bias data fill * changed test_inference_object_detection_head assertTrues to torch test assert_close * fixes part 1 * quality update * self.bbox_embed in decoder has been restored * changed Assert true torch closeall methods to torch testing assertclose * modelcard markdown file has been updated * deleted intemediate list from decoder module --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
parent
fe52679e74
commit
8d73a38606
@ -643,6 +643,8 @@
|
|||||||
title: ConvNeXTV2
|
title: ConvNeXTV2
|
||||||
- local: model_doc/cvt
|
- local: model_doc/cvt
|
||||||
title: CvT
|
title: CvT
|
||||||
|
- local: model_doc/dab-detr
|
||||||
|
title: DAB-DETR
|
||||||
- local: model_doc/deformable_detr
|
- local: model_doc/deformable_detr
|
||||||
title: Deformable DETR
|
title: Deformable DETR
|
||||||
- local: model_doc/deit
|
- local: model_doc/deit
|
||||||
|
@ -110,6 +110,7 @@ Flax), PyTorch, and/or TensorFlow.
|
|||||||
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
|
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
|
||||||
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
|
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
|
||||||
| [CvT](model_doc/cvt) | ✅ | ✅ | ❌ |
|
| [CvT](model_doc/cvt) | ✅ | ✅ | ❌ |
|
||||||
|
| [DAB-DETR](model_doc/dab-detr) | ✅ | ❌ | ❌ |
|
||||||
| [DAC](model_doc/dac) | ✅ | ❌ | ❌ |
|
| [DAC](model_doc/dac) | ✅ | ❌ | ❌ |
|
||||||
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
|
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
|
||||||
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |
|
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |
|
||||||
|
119
docs/source/en/model_doc/dab-detr.md
Normal file
119
docs/source/en/model_doc/dab-detr.md
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# DAB-DETR
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The DAB-DETR model was proposed in [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329) by Shilong Liu, Feng Li, Hao Zhang, Xiao Yang, Xianbiao Qi, Hang Su, Jun Zhu, Lei Zhang.
|
||||||
|
DAB-DETR is an enhanced variant of Conditional DETR. It utilizes dynamically updated anchor boxes to provide both a reference query point (x, y) and a reference anchor size (w, h), improving cross-attention computation. This new approach achieves 45.7% AP when trained for 50 epochs with a single ResNet-50 model as the backbone.
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dab_detr_convergence_plot.png"
|
||||||
|
alt="drawing" width="600"/>
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*We present in this paper a novel query formulation using dynamic anchor boxes
|
||||||
|
for DETR (DEtection TRansformer) and offer a deeper understanding of the role
|
||||||
|
of queries in DETR. This new formulation directly uses box coordinates as queries
|
||||||
|
in Transformer decoders and dynamically updates them layer-by-layer. Using box
|
||||||
|
coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR,
|
||||||
|
but also allows us to modulate the positional attention map using the box width
|
||||||
|
and height information. Such a design makes it clear that queries in DETR can be
|
||||||
|
implemented as performing soft ROI pooling layer-by-layer in a cascade manner.
|
||||||
|
As a result, it leads to the best performance on MS-COCO benchmark among
|
||||||
|
the DETR-like detection models under the same setting, e.g., AP 45.7% using
|
||||||
|
ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive
|
||||||
|
experiments to confirm our analysis and verify the effectiveness of our methods.*
|
||||||
|
|
||||||
|
This model was contributed by [davidhajdu](https://huggingface.co/davidhajdu).
|
||||||
|
The original code can be found [here](https://github.com/IDEA-Research/DAB-DETR).
|
||||||
|
|
||||||
|
## How to Get Started with the Model
|
||||||
|
|
||||||
|
Use the code below to get started with the model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoModelForObjectDetection, AutoImageProcessor
|
||||||
|
|
||||||
|
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||||
|
image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
|
||||||
|
model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
|
||||||
|
|
||||||
|
inputs = image_processor(images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
|
||||||
|
score, label = score.item(), label_id.item()
|
||||||
|
box = [round(i, 2) for i in box.tolist()]
|
||||||
|
print(f"{model.config.id2label[label]}: {score:.2f} {box}")
|
||||||
|
```
|
||||||
|
This should output
|
||||||
|
```
|
||||||
|
cat: 0.87 [14.7, 49.39, 320.52, 469.28]
|
||||||
|
remote: 0.86 [41.08, 72.37, 173.39, 117.2]
|
||||||
|
cat: 0.86 [344.45, 19.43, 639.85, 367.86]
|
||||||
|
remote: 0.61 [334.27, 75.93, 367.92, 188.81]
|
||||||
|
couch: 0.59 [-0.04, 1.34, 639.9, 477.09]
|
||||||
|
```
|
||||||
|
|
||||||
|
There are three other ways to instantiate a DAB-DETR model (depending on what you prefer):
|
||||||
|
|
||||||
|
Option 1: Instantiate DAB-DETR with pre-trained weights for entire model
|
||||||
|
```py
|
||||||
|
>>> from transformers import DabDetrForObjectDetection
|
||||||
|
|
||||||
|
>>> model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
|
||||||
|
```
|
||||||
|
|
||||||
|
Option 2: Instantiate DAB-DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
|
||||||
|
```py
|
||||||
|
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection
|
||||||
|
|
||||||
|
>>> config = DabDetrConfig()
|
||||||
|
>>> model = DabDetrForObjectDetection(config)
|
||||||
|
```
|
||||||
|
Option 3: Instantiate DAB-DETR with randomly initialized weights for backbone + Transformer
|
||||||
|
```py
|
||||||
|
>>> config = DabDetrConfig(use_pretrained_backbone=False)
|
||||||
|
>>> model = DabDetrForObjectDetection(config)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## DabDetrConfig
|
||||||
|
|
||||||
|
[[autodoc]] DabDetrConfig
|
||||||
|
|
||||||
|
## DabDetrModel
|
||||||
|
|
||||||
|
[[autodoc]] DabDetrModel
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## DabDetrForObjectDetection
|
||||||
|
|
||||||
|
[[autodoc]] DabDetrForObjectDetection
|
||||||
|
- forward
|
@ -328,6 +328,7 @@ _import_structure = {
|
|||||||
"CTRLTokenizer",
|
"CTRLTokenizer",
|
||||||
],
|
],
|
||||||
"models.cvt": ["CvtConfig"],
|
"models.cvt": ["CvtConfig"],
|
||||||
|
"models.dab_detr": ["DabDetrConfig"],
|
||||||
"models.dac": ["DacConfig", "DacFeatureExtractor"],
|
"models.dac": ["DacConfig", "DacFeatureExtractor"],
|
||||||
"models.data2vec": [
|
"models.data2vec": [
|
||||||
"Data2VecAudioConfig",
|
"Data2VecAudioConfig",
|
||||||
@ -1898,6 +1899,13 @@ else:
|
|||||||
"CvtPreTrainedModel",
|
"CvtPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.dab_detr"].extend(
|
||||||
|
[
|
||||||
|
"DabDetrForObjectDetection",
|
||||||
|
"DabDetrModel",
|
||||||
|
"DabDetrPreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.dac"].extend(
|
_import_structure["models.dac"].extend(
|
||||||
[
|
[
|
||||||
"DacModel",
|
"DacModel",
|
||||||
@ -5387,6 +5395,9 @@ if TYPE_CHECKING:
|
|||||||
CTRLTokenizer,
|
CTRLTokenizer,
|
||||||
)
|
)
|
||||||
from .models.cvt import CvtConfig
|
from .models.cvt import CvtConfig
|
||||||
|
from .models.dab_detr import (
|
||||||
|
DabDetrConfig,
|
||||||
|
)
|
||||||
from .models.dac import (
|
from .models.dac import (
|
||||||
DacConfig,
|
DacConfig,
|
||||||
DacFeatureExtractor,
|
DacFeatureExtractor,
|
||||||
@ -6926,6 +6937,11 @@ if TYPE_CHECKING:
|
|||||||
CvtModel,
|
CvtModel,
|
||||||
CvtPreTrainedModel,
|
CvtPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.dab_detr import (
|
||||||
|
DabDetrForObjectDetection,
|
||||||
|
DabDetrModel,
|
||||||
|
DabDetrPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.dac import (
|
from .models.dac import (
|
||||||
DacModel,
|
DacModel,
|
||||||
DacPreTrainedModel,
|
DacPreTrainedModel,
|
||||||
|
@ -217,6 +217,7 @@ ACT2CLS = {
|
|||||||
"silu": nn.SiLU,
|
"silu": nn.SiLU,
|
||||||
"swish": nn.SiLU,
|
"swish": nn.SiLU,
|
||||||
"tanh": nn.Tanh,
|
"tanh": nn.Tanh,
|
||||||
|
"prelu": nn.PReLU,
|
||||||
}
|
}
|
||||||
ACT2FN = ClassInstantier(ACT2CLS)
|
ACT2FN = ClassInstantier(ACT2CLS)
|
||||||
|
|
||||||
|
@ -128,6 +128,7 @@ LOSS_MAPPING = {
|
|||||||
"ForObjectDetection": ForObjectDetectionLoss,
|
"ForObjectDetection": ForObjectDetectionLoss,
|
||||||
"DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
"DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||||
"ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
"ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||||
|
"DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||||
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
"GroundingDinoForObjectDetection": DeformableDetrForObjectDetectionLoss,
|
||||||
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
|
"ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
|
||||||
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
|
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
|
||||||
|
@ -63,6 +63,7 @@ from . import (
|
|||||||
cpmant,
|
cpmant,
|
||||||
ctrl,
|
ctrl,
|
||||||
cvt,
|
cvt,
|
||||||
|
dab_detr,
|
||||||
dac,
|
dac,
|
||||||
data2vec,
|
data2vec,
|
||||||
dbrx,
|
dbrx,
|
||||||
|
@ -79,6 +79,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("cpmant", "CpmAntConfig"),
|
("cpmant", "CpmAntConfig"),
|
||||||
("ctrl", "CTRLConfig"),
|
("ctrl", "CTRLConfig"),
|
||||||
("cvt", "CvtConfig"),
|
("cvt", "CvtConfig"),
|
||||||
|
("dab-detr", "DabDetrConfig"),
|
||||||
("dac", "DacConfig"),
|
("dac", "DacConfig"),
|
||||||
("data2vec-audio", "Data2VecAudioConfig"),
|
("data2vec-audio", "Data2VecAudioConfig"),
|
||||||
("data2vec-text", "Data2VecTextConfig"),
|
("data2vec-text", "Data2VecTextConfig"),
|
||||||
@ -399,6 +400,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("cpmant", "CPM-Ant"),
|
("cpmant", "CPM-Ant"),
|
||||||
("ctrl", "CTRL"),
|
("ctrl", "CTRL"),
|
||||||
("cvt", "CvT"),
|
("cvt", "CvT"),
|
||||||
|
("dab-detr", "DAB-DETR"),
|
||||||
("dac", "DAC"),
|
("dac", "DAC"),
|
||||||
("data2vec-audio", "Data2VecAudio"),
|
("data2vec-audio", "Data2VecAudio"),
|
||||||
("data2vec-text", "Data2VecText"),
|
("data2vec-text", "Data2VecText"),
|
||||||
|
@ -78,6 +78,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("cpmant", "CpmAntModel"),
|
("cpmant", "CpmAntModel"),
|
||||||
("ctrl", "CTRLModel"),
|
("ctrl", "CTRLModel"),
|
||||||
("cvt", "CvtModel"),
|
("cvt", "CvtModel"),
|
||||||
|
("dab-detr", "DabDetrModel"),
|
||||||
("dac", "DacModel"),
|
("dac", "DacModel"),
|
||||||
("data2vec-audio", "Data2VecAudioModel"),
|
("data2vec-audio", "Data2VecAudioModel"),
|
||||||
("data2vec-text", "Data2VecTextModel"),
|
("data2vec-text", "Data2VecTextModel"),
|
||||||
@ -592,6 +593,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
|
|||||||
("conditional_detr", "ConditionalDetrModel"),
|
("conditional_detr", "ConditionalDetrModel"),
|
||||||
("convnext", "ConvNextModel"),
|
("convnext", "ConvNextModel"),
|
||||||
("convnextv2", "ConvNextV2Model"),
|
("convnextv2", "ConvNextV2Model"),
|
||||||
|
("dab-detr", "DabDetrModel"),
|
||||||
("data2vec-vision", "Data2VecVisionModel"),
|
("data2vec-vision", "Data2VecVisionModel"),
|
||||||
("deformable_detr", "DeformableDetrModel"),
|
("deformable_detr", "DeformableDetrModel"),
|
||||||
("deit", "DeiTModel"),
|
("deit", "DeiTModel"),
|
||||||
@ -890,6 +892,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
|||||||
[
|
[
|
||||||
# Model for Object Detection mapping
|
# Model for Object Detection mapping
|
||||||
("conditional_detr", "ConditionalDetrForObjectDetection"),
|
("conditional_detr", "ConditionalDetrForObjectDetection"),
|
||||||
|
("dab-detr", "DabDetrForObjectDetection"),
|
||||||
("deformable_detr", "DeformableDetrForObjectDetection"),
|
("deformable_detr", "DeformableDetrForObjectDetection"),
|
||||||
("deta", "DetaForObjectDetection"),
|
("deta", "DetaForObjectDetection"),
|
||||||
("detr", "DetrForObjectDetection"),
|
("detr", "DetrForObjectDetection"),
|
||||||
|
@ -52,7 +52,7 @@ class ConditionalDetrConfig(PretrainedConfig):
|
|||||||
Number of object queries, i.e. detection slots. This is the maximal number of objects
|
Number of object queries, i.e. detection slots. This is the maximal number of objects
|
||||||
[`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
|
[`ConditionalDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
d_model (`int`, *optional*, defaults to 256):
|
d_model (`int`, *optional*, defaults to 256):
|
||||||
Dimension of the layers.
|
This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others.
|
||||||
encoder_layers (`int`, *optional*, defaults to 6):
|
encoder_layers (`int`, *optional*, defaults to 6):
|
||||||
Number of encoder layers.
|
Number of encoder layers.
|
||||||
decoder_layers (`int`, *optional*, defaults to 6):
|
decoder_layers (`int`, *optional*, defaults to 6):
|
||||||
|
@ -74,6 +74,8 @@ class ConditionalDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
|
|||||||
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
||||||
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
||||||
layernorm.
|
layernorm.
|
||||||
|
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
|
||||||
|
Reference points (reference points of each layer of the decoder).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
@ -116,6 +118,8 @@ class ConditionalDetrModelOutput(Seq2SeqModelOutput):
|
|||||||
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, sequence_length, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
|
||||||
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
|
||||||
layernorm.
|
layernorm.
|
||||||
|
reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 2 (anchor points))`):
|
||||||
|
Reference points (reference points of each layer of the decoder).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
intermediate_hidden_states: Optional[torch.FloatTensor] = None
|
||||||
|
28
src/transformers/models/dab_detr/__init__.py
Normal file
28
src/transformers/models/dab_detr/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_dab_detr import *
|
||||||
|
from .modeling_dab_detr import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
260
src/transformers/models/dab_detr/configuration_dab_detr.py
Normal file
260
src/transformers/models/dab_detr/configuration_dab_detr.py
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""DAB-DETR model configuration"""
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||||
|
from ..auto import CONFIG_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DabDetrConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`DabDetrModel`]. It is used to instantiate
|
||||||
|
a DAB-DETR model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
|
configuration with the defaults will yield a similar configuration to that of the DAB-DETR
|
||||||
|
[IDEA-Research/dab_detr-base](https://huggingface.co/IDEA-Research/dab_detr-base) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_timm_backbone (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
|
||||||
|
API.
|
||||||
|
backbone_config (`PretrainedConfig` or `dict`, *optional*):
|
||||||
|
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
|
||||||
|
case it will default to `ResNetConfig()`.
|
||||||
|
backbone (`str`, *optional*, defaults to `"resnet50"`):
|
||||||
|
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
||||||
|
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
||||||
|
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||||
|
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use pretrained weights for the backbone.
|
||||||
|
backbone_kwargs (`dict`, *optional*):
|
||||||
|
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||||
|
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||||
|
num_queries (`int`, *optional*, defaults to 300):
|
||||||
|
Number of object queries, i.e. detection slots. This is the maximal number of objects
|
||||||
|
[`DabDetrModel`] can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
|
encoder_layers (`int`, *optional*, defaults to 6):
|
||||||
|
Number of encoder layers.
|
||||||
|
encoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
||||||
|
Dimension of the "intermediate" (often named feed-forward) layer in encoder.
|
||||||
|
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
||||||
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
|
decoder_layers (`int`, *optional*, defaults to 6):
|
||||||
|
Number of decoder layers.
|
||||||
|
decoder_ffn_dim (`int`, *optional*, defaults to 2048):
|
||||||
|
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
||||||
|
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
||||||
|
Indicates whether the transformer model architecture is an encoder-decoder or not.
|
||||||
|
activation_function (`str` or `function`, *optional*, defaults to `"prelu"`):
|
||||||
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
|
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||||
|
hidden_size (`int`, *optional*, defaults to 256):
|
||||||
|
This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others.
|
||||||
|
dropout (`float`, *optional*, defaults to 0.1):
|
||||||
|
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for activations inside the fully connected layer.
|
||||||
|
init_std (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
init_xavier_std (`float`, *optional*, defaults to 1.0):
|
||||||
|
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
|
||||||
|
auxiliary_loss (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||||
|
dilation (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when `use_timm_backbone` = `True`.
|
||||||
|
class_cost (`float`, *optional*, defaults to 2):
|
||||||
|
Relative weight of the classification error in the Hungarian matching cost.
|
||||||
|
bbox_cost (`float`, *optional*, defaults to 5):
|
||||||
|
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
|
||||||
|
giou_cost (`float`, *optional*, defaults to 2):
|
||||||
|
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
|
||||||
|
cls_loss_coefficient (`float`, *optional*, defaults to 2):
|
||||||
|
Relative weight of the classification loss in the object detection loss function.
|
||||||
|
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
|
||||||
|
Relative weight of the L1 bounding box loss in the object detection loss.
|
||||||
|
giou_loss_coefficient (`float`, *optional*, defaults to 2):
|
||||||
|
Relative weight of the generalized IoU loss in the object detection loss.
|
||||||
|
focal_alpha (`float`, *optional*, defaults to 0.25):
|
||||||
|
Alpha parameter in the focal loss.
|
||||||
|
temperature_height (`int`, *optional*, defaults to 20):
|
||||||
|
Temperature parameter to tune the flatness of positional attention (HEIGHT)
|
||||||
|
temperature_width (`int`, *optional*, defaults to 20):
|
||||||
|
Temperature parameter to tune the flatness of positional attention (WIDTH)
|
||||||
|
query_dim (`int`, *optional*, defaults to 4):
|
||||||
|
Query dimension parameter represents the size of the output vector.
|
||||||
|
random_refpoints_xy (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to fix the x and y coordinates of the anchor boxes with random initialization.
|
||||||
|
keep_query_pos (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to concatenate the projected positional embedding from the object query into the original query (key) in every decoder layer.
|
||||||
|
num_patterns (`int`, *optional*, defaults to 0):
|
||||||
|
Number of pattern embeddings.
|
||||||
|
normalize_before (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether we use a normalization layer in the Encoder or not.
|
||||||
|
sine_position_embedding_scale (`float`, *optional*, defaults to 'None'):
|
||||||
|
Scaling factor applied to the normalized positional encodings.
|
||||||
|
initializer_bias_prior_prob (`float`, *optional*):
|
||||||
|
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
|
||||||
|
If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
|
||||||
|
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import DabDetrConfig, DabDetrModel
|
||||||
|
|
||||||
|
>>> # Initializing a DAB-DETR IDEA-Research/dab_detr-base style configuration
|
||||||
|
>>> configuration = DabDetrConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model (with random weights) from the IDEA-Research/dab_detr-base style configuration
|
||||||
|
>>> model = DabDetrModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "dab-detr"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
attribute_map = {
|
||||||
|
"num_attention_heads": "encoder_attention_heads",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_timm_backbone=True,
|
||||||
|
backbone_config=None,
|
||||||
|
backbone="resnet50",
|
||||||
|
use_pretrained_backbone=True,
|
||||||
|
backbone_kwargs=None,
|
||||||
|
num_queries=300,
|
||||||
|
encoder_layers=6,
|
||||||
|
encoder_ffn_dim=2048,
|
||||||
|
encoder_attention_heads=8,
|
||||||
|
decoder_layers=6,
|
||||||
|
decoder_ffn_dim=2048,
|
||||||
|
decoder_attention_heads=8,
|
||||||
|
is_encoder_decoder=True,
|
||||||
|
activation_function="prelu",
|
||||||
|
hidden_size=256,
|
||||||
|
dropout=0.1,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
activation_dropout=0.0,
|
||||||
|
init_std=0.02,
|
||||||
|
init_xavier_std=1.0,
|
||||||
|
auxiliary_loss=False,
|
||||||
|
dilation=False,
|
||||||
|
class_cost=2,
|
||||||
|
bbox_cost=5,
|
||||||
|
giou_cost=2,
|
||||||
|
cls_loss_coefficient=2,
|
||||||
|
bbox_loss_coefficient=5,
|
||||||
|
giou_loss_coefficient=2,
|
||||||
|
focal_alpha=0.25,
|
||||||
|
temperature_height=20,
|
||||||
|
temperature_width=20,
|
||||||
|
query_dim=4,
|
||||||
|
random_refpoints_xy=False,
|
||||||
|
keep_query_pos=False,
|
||||||
|
num_patterns=0,
|
||||||
|
normalize_before=False,
|
||||||
|
sine_position_embedding_scale=None,
|
||||||
|
initializer_bias_prior_prob=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if query_dim != 4:
|
||||||
|
raise ValueError("The query dimensions has to be 4.")
|
||||||
|
|
||||||
|
# We default to values which were previously hard-coded in the model. This enables configurability of the config
|
||||||
|
# while keeping the default behavior the same.
|
||||||
|
if use_timm_backbone and backbone_kwargs is None:
|
||||||
|
backbone_kwargs = {}
|
||||||
|
if dilation:
|
||||||
|
backbone_kwargs["output_stride"] = 16
|
||||||
|
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
|
||||||
|
backbone_kwargs["in_chans"] = 3 # num_channels
|
||||||
|
# Backwards compatibility
|
||||||
|
elif not use_timm_backbone and backbone in (None, "resnet50"):
|
||||||
|
if backbone_config is None:
|
||||||
|
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
|
||||||
|
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
|
||||||
|
elif isinstance(backbone_config, dict):
|
||||||
|
backbone_model_type = backbone_config.get("model_type")
|
||||||
|
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||||
|
backbone_config = config_class.from_dict(backbone_config)
|
||||||
|
backbone = None
|
||||||
|
# set timm attributes to None
|
||||||
|
dilation = None
|
||||||
|
|
||||||
|
verify_backbone_config_arguments(
|
||||||
|
use_timm_backbone=use_timm_backbone,
|
||||||
|
use_pretrained_backbone=use_pretrained_backbone,
|
||||||
|
backbone=backbone,
|
||||||
|
backbone_config=backbone_config,
|
||||||
|
backbone_kwargs=backbone_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.use_timm_backbone = use_timm_backbone
|
||||||
|
self.backbone_config = backbone_config
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.encoder_ffn_dim = encoder_ffn_dim
|
||||||
|
self.encoder_layers = encoder_layers
|
||||||
|
self.encoder_attention_heads = encoder_attention_heads
|
||||||
|
self.decoder_ffn_dim = decoder_ffn_dim
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.decoder_attention_heads = decoder_attention_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
self.activation_function = activation_function
|
||||||
|
self.init_std = init_std
|
||||||
|
self.init_xavier_std = init_xavier_std
|
||||||
|
self.num_hidden_layers = encoder_layers
|
||||||
|
self.auxiliary_loss = auxiliary_loss
|
||||||
|
self.backbone = backbone
|
||||||
|
self.use_pretrained_backbone = use_pretrained_backbone
|
||||||
|
self.backbone_kwargs = backbone_kwargs
|
||||||
|
# Hungarian matcher
|
||||||
|
self.class_cost = class_cost
|
||||||
|
self.bbox_cost = bbox_cost
|
||||||
|
self.giou_cost = giou_cost
|
||||||
|
# Loss coefficients
|
||||||
|
self.cls_loss_coefficient = cls_loss_coefficient
|
||||||
|
self.bbox_loss_coefficient = bbox_loss_coefficient
|
||||||
|
self.giou_loss_coefficient = giou_loss_coefficient
|
||||||
|
self.focal_alpha = focal_alpha
|
||||||
|
self.query_dim = query_dim
|
||||||
|
self.random_refpoints_xy = random_refpoints_xy
|
||||||
|
self.keep_query_pos = keep_query_pos
|
||||||
|
self.num_patterns = num_patterns
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
self.temperature_width = temperature_width
|
||||||
|
self.temperature_height = temperature_height
|
||||||
|
self.sine_position_embedding_scale = sine_position_embedding_scale
|
||||||
|
self.initializer_bias_prior_prob = initializer_bias_prior_prob
|
||||||
|
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DabDetrConfig"]
|
@ -0,0 +1,233 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert DAB-DETR checkpoints."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||||
|
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
||||||
|
# for dab-DETR, also convert reference point head and query scale MLP
|
||||||
|
r"input_proj\.(bias|weight)": r"input_projection.\1",
|
||||||
|
r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight",
|
||||||
|
r"class_embed\.(bias|weight)": r"class_embed.\1",
|
||||||
|
# negative lookbehind because of the overlap
|
||||||
|
r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2",
|
||||||
|
r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2",
|
||||||
|
r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
|
||||||
|
r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1",
|
||||||
|
r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2",
|
||||||
|
r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2",
|
||||||
|
r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2",
|
||||||
|
r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1",
|
||||||
|
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function
|
||||||
|
# output projection
|
||||||
|
r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2",
|
||||||
|
# FFN layers
|
||||||
|
r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3",
|
||||||
|
# normalization layers
|
||||||
|
# nm1
|
||||||
|
r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2",
|
||||||
|
# nm2
|
||||||
|
r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2",
|
||||||
|
# activation function weight
|
||||||
|
r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight",
|
||||||
|
#########################################################################################################################################
|
||||||
|
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activiation function weight
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2",
|
||||||
|
# FFNs
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3",
|
||||||
|
# nm1
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2",
|
||||||
|
# nm2
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2",
|
||||||
|
# nm3
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2",
|
||||||
|
# activation function weight
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight",
|
||||||
|
# q, k, v projections and biases in self-attention in decoder
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2",
|
||||||
|
# q, k, v projections in cross-attention in decoder
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2",
|
||||||
|
r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys
|
||||||
|
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
|
||||||
|
"""
|
||||||
|
This function should be applied only once, on the concatenated keys to efficiently rename using
|
||||||
|
the key mappings.
|
||||||
|
"""
|
||||||
|
output_dict = {}
|
||||||
|
if state_dict_keys is not None:
|
||||||
|
old_text = "\n".join(state_dict_keys)
|
||||||
|
new_text = old_text
|
||||||
|
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||||
|
if replacement is None:
|
||||||
|
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||||
|
continue
|
||||||
|
new_text = re.sub(pattern, replacement, new_text)
|
||||||
|
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||||
|
return output_dict
|
||||||
|
|
||||||
|
|
||||||
|
def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub):
|
||||||
|
logger.info("Converting image processor...")
|
||||||
|
format = "coco_detection"
|
||||||
|
image_processor = ConditionalDetrImageProcessor(format=format)
|
||||||
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||||
|
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
|
||||||
|
# load modified config. Why? After loading the default config, the backbone kwargs are already set.
|
||||||
|
if "dc5" in model_name:
|
||||||
|
config = DabDetrConfig(dilation=True)
|
||||||
|
else:
|
||||||
|
# load default config
|
||||||
|
config = DabDetrConfig()
|
||||||
|
# set other attributes
|
||||||
|
if "dab-detr-resnet-50-dc5" == model_name:
|
||||||
|
config.temperature_height = 10
|
||||||
|
config.temperature_width = 10
|
||||||
|
if "fixxy" in model_name:
|
||||||
|
config.random_refpoints_xy = True
|
||||||
|
if "pat3" in model_name:
|
||||||
|
config.num_patterns = 3
|
||||||
|
# only when the number of patterns (num_patterns parameter in config) are more than 0 like r50-pat3 or r50dc5-pat3
|
||||||
|
ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"})
|
||||||
|
|
||||||
|
config.num_labels = 91
|
||||||
|
repo_id = "huggingface/label-files"
|
||||||
|
filename = "coco-detection-id2label.json"
|
||||||
|
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||||
|
id2label = {int(k): v for k, v in id2label.items()}
|
||||||
|
config.id2label = id2label
|
||||||
|
config.label2id = {v: k for k, v in id2label.items()}
|
||||||
|
# load original model from local path
|
||||||
|
loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"))["model"]
|
||||||
|
# Renaming the original model state dictionary to HF compatibile
|
||||||
|
all_keys = list(loaded.keys())
|
||||||
|
new_keys = convert_old_keys_to_new_keys(all_keys)
|
||||||
|
state_dict = {}
|
||||||
|
for key in all_keys:
|
||||||
|
if "backbone.0.body" in key:
|
||||||
|
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone")
|
||||||
|
state_dict[new_key] = loaded[key]
|
||||||
|
# Q, K, V encoder values mapping
|
||||||
|
elif re.search("self_attn.in_proj_(weight|bias)", key):
|
||||||
|
# Dynamically find the layer number
|
||||||
|
pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)"
|
||||||
|
match = re.search(pattern, key)
|
||||||
|
if match:
|
||||||
|
layer_num = match.group(1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Pattern not found in key: {key}")
|
||||||
|
|
||||||
|
in_proj_value = loaded.pop(key)
|
||||||
|
if "weight" in key:
|
||||||
|
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :]
|
||||||
|
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :]
|
||||||
|
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :]
|
||||||
|
elif "bias" in key:
|
||||||
|
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256]
|
||||||
|
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512]
|
||||||
|
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:]
|
||||||
|
else:
|
||||||
|
new_key = new_keys[key]
|
||||||
|
state_dict[new_key] = loaded[key]
|
||||||
|
|
||||||
|
del loaded
|
||||||
|
gc.collect()
|
||||||
|
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
||||||
|
prefix = "model."
|
||||||
|
for key in state_dict.copy().keys():
|
||||||
|
if not key.startswith("class_embed") and not key.startswith("bbox_predictor"):
|
||||||
|
val = state_dict.pop(key)
|
||||||
|
state_dict[prefix + key] = val
|
||||||
|
# finally, create HuggingFace model and load state dict
|
||||||
|
model = DabDetrForObjectDetection(config)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...")
|
||||||
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||||
|
model.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
model.push_to_hub(repo_id=model_name, commit_message="Add new model")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
|
||||||
|
logger.info("Converting image processor...")
|
||||||
|
write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub)
|
||||||
|
|
||||||
|
logger.info(f"Converting model {model_name}...")
|
||||||
|
write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name",
|
||||||
|
default="dab-detr-resnet-50",
|
||||||
|
type=str,
|
||||||
|
help="Name of the DAB_DETR model you'd like to convert.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_model_weights_path",
|
||||||
|
default="modelzoo/R50/checkpoint.pth",
|
||||||
|
type=str,
|
||||||
|
help="The path of the original model weights like: modelzoo/checkpoint.pth",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--push_to_hub",
|
||||||
|
default=True,
|
||||||
|
type=bool,
|
||||||
|
help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_dab_detr_checkpoint(
|
||||||
|
args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub
|
||||||
|
)
|
1716
src/transformers/models/dab_detr/modeling_dab_detr.py
Normal file
1716
src/transformers/models/dab_detr/modeling_dab_detr.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -52,7 +52,7 @@ class DetrConfig(PretrainedConfig):
|
|||||||
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetrModel`] can
|
||||||
detect in a single image. For COCO, we recommend 100 queries.
|
detect in a single image. For COCO, we recommend 100 queries.
|
||||||
d_model (`int`, *optional*, defaults to 256):
|
d_model (`int`, *optional*, defaults to 256):
|
||||||
Dimension of the layers.
|
This parameter is a general dimension parameter, defining dimensions for components such as the encoder layer and projection parameters in the decoder layer, among others.
|
||||||
encoder_layers (`int`, *optional*, defaults to 6):
|
encoder_layers (`int`, *optional*, defaults to 6):
|
||||||
Number of encoder layers.
|
Number of encoder layers.
|
||||||
decoder_layers (`int`, *optional*, defaults to 6):
|
decoder_layers (`int`, *optional*, defaults to 6):
|
||||||
|
@ -2482,6 +2482,27 @@ class CvtPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DabDetrForObjectDetection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DabDetrModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DabDetrPreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class DacModel(metaclass=DummyObject):
|
class DacModel(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
0
tests/models/dab_detr/__init__.py
Normal file
0
tests/models/dab_detr/__init__.py
Normal file
839
tests/models/dab_detr/test_modeling_dab_detr.py
Normal file
839
tests/models/dab_detr/test_modeling_dab_detr.py
Normal file
@ -0,0 +1,839 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch DAB-DETR model."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import math
|
||||||
|
import unittest
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from transformers import DabDetrConfig, ResNetConfig, is_torch_available, is_vision_available
|
||||||
|
from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
|
||||||
|
from transformers.utils import cached_property
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
DabDetrForObjectDetection,
|
||||||
|
DabDetrModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from transformers import ConditionalDetrImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class DabDetrModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=8,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=True,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=2,
|
||||||
|
num_attention_heads=8,
|
||||||
|
intermediate_size=4,
|
||||||
|
hidden_act="gelu",
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
num_queries=12,
|
||||||
|
num_channels=3,
|
||||||
|
min_size=200,
|
||||||
|
max_size=200,
|
||||||
|
n_targets=8,
|
||||||
|
num_labels=91,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.min_size = min_size
|
||||||
|
self.max_size = max_size
|
||||||
|
self.n_targets = n_targets
|
||||||
|
self.num_labels = num_labels
|
||||||
|
|
||||||
|
# we also set the expected seq length for both encoder and decoder
|
||||||
|
self.encoder_seq_length = math.ceil(self.min_size / 32) * math.ceil(self.max_size / 32)
|
||||||
|
self.decoder_seq_length = self.num_queries
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])
|
||||||
|
|
||||||
|
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
|
||||||
|
|
||||||
|
labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
|
||||||
|
labels = []
|
||||||
|
for i in range(self.batch_size):
|
||||||
|
target = {}
|
||||||
|
target["class_labels"] = torch.randint(
|
||||||
|
high=self.num_labels, size=(self.n_targets,), device=torch_device
|
||||||
|
)
|
||||||
|
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
|
||||||
|
target["masks"] = torch.rand(self.n_targets, self.min_size, self.max_size, device=torch_device)
|
||||||
|
labels.append(target)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
return config, pixel_values, pixel_mask, labels
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
resnet_config = ResNetConfig(
|
||||||
|
num_channels=3,
|
||||||
|
embeddings_size=10,
|
||||||
|
hidden_sizes=[10, 20, 30, 40],
|
||||||
|
depths=[1, 1, 2, 1],
|
||||||
|
hidden_act="relu",
|
||||||
|
num_labels=3,
|
||||||
|
out_features=["stage2", "stage3", "stage4"],
|
||||||
|
out_indices=[2, 3, 4],
|
||||||
|
)
|
||||||
|
return DabDetrConfig(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
encoder_layers=self.num_hidden_layers,
|
||||||
|
decoder_layers=self.num_hidden_layers,
|
||||||
|
encoder_attention_heads=self.num_attention_heads,
|
||||||
|
decoder_attention_heads=self.num_attention_heads,
|
||||||
|
encoder_ffn_dim=self.intermediate_size,
|
||||||
|
decoder_ffn_dim=self.intermediate_size,
|
||||||
|
dropout=self.hidden_dropout_prob,
|
||||||
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
|
num_queries=self.num_queries,
|
||||||
|
num_labels=self.num_labels,
|
||||||
|
use_timm_backbone=False,
|
||||||
|
backbone_config=resnet_config,
|
||||||
|
backbone=None,
|
||||||
|
use_pretrained_backbone=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
|
||||||
|
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def create_and_check_dab_detr_model(self, config, pixel_values, pixel_mask, labels):
|
||||||
|
model = DabDetrModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.last_hidden_state.shape, (self.batch_size, self.decoder_seq_length, self.hidden_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_dab_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
|
||||||
|
model = DabDetrForObjectDetection(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||||
|
result = model(pixel_values)
|
||||||
|
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||||
|
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||||
|
|
||||||
|
result = model(pixel_values=pixel_values, labels=labels)
|
||||||
|
|
||||||
|
self.parent.assertEqual(result.loss.shape, ())
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||||
|
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class DabDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
DabDetrModel,
|
||||||
|
DabDetrForObjectDetection,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"image-feature-extraction": DabDetrModel,
|
||||||
|
"object-detection": DabDetrForObjectDetection,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
is_encoder_decoder = True
|
||||||
|
test_torchscript = False
|
||||||
|
test_pruning = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_missing_keys = False
|
||||||
|
zero_init_hidden_state = True
|
||||||
|
|
||||||
|
# special case for head models
|
||||||
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||||
|
|
||||||
|
if return_labels:
|
||||||
|
if model_class.__name__ in ["DabDetrForObjectDetection"]:
|
||||||
|
labels = []
|
||||||
|
for i in range(self.model_tester.batch_size):
|
||||||
|
target = {}
|
||||||
|
target["class_labels"] = torch.ones(
|
||||||
|
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
|
||||||
|
)
|
||||||
|
target["boxes"] = torch.ones(
|
||||||
|
self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
|
||||||
|
)
|
||||||
|
target["masks"] = torch.ones(
|
||||||
|
self.model_tester.n_targets,
|
||||||
|
self.model_tester.min_size,
|
||||||
|
self.model_tester.max_size,
|
||||||
|
device=torch_device,
|
||||||
|
dtype=torch.float,
|
||||||
|
)
|
||||||
|
labels.append(target)
|
||||||
|
inputs_dict["labels"] = labels
|
||||||
|
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = DabDetrModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=DabDetrConfig, has_text_modality=False)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_dab_detr_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_dab_detr_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_dab_detr_object_detection_head_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_dab_detr_object_detection_head_model(*config_and_inputs)
|
||||||
|
|
||||||
|
# TODO: check if this works again for PyTorch 2.x.y
|
||||||
|
@unittest.skip(reason="Got `CUDA error: misaligned address` with PyTorch 2.0.0.")
|
||||||
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="DETR does not use inputs_embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="DETR does not use inputs_embeds")
|
||||||
|
def test_model_get_set_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="DETR does not use inputs_embeds")
|
||||||
|
def test_inputs_embeds_matches_input_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="DETR does not have a get_input_embeddings method")
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="DETR is not a generative model")
|
||||||
|
def test_generate_without_input_ids(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="DETR does not use token embeddings")
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_outputs_equivalence(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
def set_nan_tensor_to_zero(t):
|
||||||
|
print(t)
|
||||||
|
t[t != t] = 0
|
||||||
|
return t
|
||||||
|
|
||||||
|
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||||
|
with torch.no_grad():
|
||||||
|
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
|
|
||||||
|
def recursive_check(tuple_object, dict_object):
|
||||||
|
if isinstance(tuple_object, (List, Tuple)):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif isinstance(tuple_object, Dict):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(
|
||||||
|
tuple_object.values(), dict_object.values()
|
||||||
|
):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif tuple_object is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
torch.testing.assert_close(
|
||||||
|
set_nan_tensor_to_zero(tuple_object),
|
||||||
|
set_nan_tensor_to_zero(dict_object),
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=1e-5,
|
||||||
|
msg=(
|
||||||
|
"Tuple and dict output are not equal. Difference:"
|
||||||
|
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||||
|
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||||
|
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
if self.has_attentions:
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
check_equivalence(
|
||||||
|
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||||
|
|
||||||
|
expected_num_layers = getattr(
|
||||||
|
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
|
||||||
|
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||||
|
seq_length = self.model_tester.encoder_seq_length
|
||||||
|
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
|
||||||
|
seq_length = seq_length * self.model_tester.chunk_length
|
||||||
|
else:
|
||||||
|
seq_length = self.model_tester.seq_length
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
[hidden_states[0].shape[1], hidden_states[0].shape[2]],
|
||||||
|
[seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
hidden_states = outputs.decoder_hidden_states
|
||||||
|
|
||||||
|
self.assertIsInstance(hidden_states, (list, tuple))
|
||||||
|
|
||||||
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
[hidden_states[0].shape[1], hidden_states[0].shape[2]],
|
||||||
|
[decoder_seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
# check that output_hidden_states also work using config
|
||||||
|
del inputs_dict["output_hidden_states"]
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
# Had to modify the threshold to 2 decimals instead of 3 because sometimes it threw an error
|
||||||
|
def test_batching_equivalence(self):
|
||||||
|
"""
|
||||||
|
Tests that the model supports batching and that the output is the nearly the same for the same input in
|
||||||
|
different batch sizes.
|
||||||
|
(Why "nearly the same" not "exactly the same"? Batching uses different matmul shapes, which often leads to
|
||||||
|
different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_tensor_equivalence_function(batched_input):
|
||||||
|
# models operating on continuous spaces have higher abs difference than LMs
|
||||||
|
# instead, we can rely on cos distance for image/speech models, similar to `diffusers`
|
||||||
|
if "input_ids" not in batched_input:
|
||||||
|
return lambda tensor1, tensor2: (
|
||||||
|
1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)
|
||||||
|
)
|
||||||
|
return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2))
|
||||||
|
|
||||||
|
def recursive_check(batched_object, single_row_object, model_name, key):
|
||||||
|
if isinstance(batched_object, (list, tuple)):
|
||||||
|
for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
|
||||||
|
recursive_check(batched_object_value, single_row_object_value, model_name, key)
|
||||||
|
elif isinstance(batched_object, dict):
|
||||||
|
for batched_object_value, single_row_object_value in zip(
|
||||||
|
batched_object.values(), single_row_object.values()
|
||||||
|
):
|
||||||
|
recursive_check(batched_object_value, single_row_object_value, model_name, key)
|
||||||
|
# do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
|
||||||
|
elif batched_object is None or not isinstance(batched_object, torch.Tensor):
|
||||||
|
return
|
||||||
|
elif batched_object.dim() == 0:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# indexing the first element does not always work
|
||||||
|
# e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
|
||||||
|
slice_ids = [slice(0, index) for index in single_row_object.shape]
|
||||||
|
batched_row = batched_object[slice_ids]
|
||||||
|
self.assertFalse(
|
||||||
|
torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
(equivalence(batched_row, single_row_object)) <= 1e-02,
|
||||||
|
msg=(
|
||||||
|
f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
|
||||||
|
f"Difference={equivalence(batched_row, single_row_object)}."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
equivalence = get_tensor_equivalence_function(batched_input)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
model_name = model_class.__name__
|
||||||
|
if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
|
||||||
|
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
|
||||||
|
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
|
batch_size = self.model_tester.batch_size
|
||||||
|
single_row_input = {}
|
||||||
|
for key, value in batched_input_prepared.items():
|
||||||
|
if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
|
||||||
|
# e.g. musicgen has inputs of size (bs*codebooks). in most cases value.shape[0] == batch_size
|
||||||
|
single_batch_shape = value.shape[0] // batch_size
|
||||||
|
single_row_input[key] = value[:single_batch_shape]
|
||||||
|
else:
|
||||||
|
single_row_input[key] = value
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model_batched_output = model(**batched_input_prepared)
|
||||||
|
model_row_output = model(**single_row_input)
|
||||||
|
|
||||||
|
if isinstance(model_batched_output, torch.Tensor):
|
||||||
|
model_batched_output = {"model_output": model_batched_output}
|
||||||
|
model_row_output = {"model_output": model_row_output}
|
||||||
|
|
||||||
|
for key in model_batched_output:
|
||||||
|
# DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan`
|
||||||
|
if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key:
|
||||||
|
model_batched_output[key] = model_batched_output[key][1:]
|
||||||
|
model_row_output[key] = model_row_output[key][1:]
|
||||||
|
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
|
||||||
|
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
decoder_seq_length = self.model_tester.decoder_seq_length
|
||||||
|
encoder_seq_length = self.model_tester.encoder_seq_length
|
||||||
|
decoder_key_length = self.model_tester.decoder_seq_length
|
||||||
|
encoder_key_length = self.model_tester.encoder_seq_length
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = False
|
||||||
|
config.return_dict = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
# check that output_attentions also work using config
|
||||||
|
del inputs_dict["output_attentions"]
|
||||||
|
del inputs_dict["output_hidden_states"]
|
||||||
|
config.output_attentions = True
|
||||||
|
config.output_hidden_states = False
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
out_len = len(outputs)
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
correct_outlen = 6
|
||||||
|
|
||||||
|
# loss is at first position
|
||||||
|
if "labels" in inputs_dict:
|
||||||
|
correct_outlen += 1 # loss is added to beginning
|
||||||
|
if "past_key_values" in outputs:
|
||||||
|
correct_outlen += 1 # past_key_values have been returned
|
||||||
|
|
||||||
|
self.assertEqual(out_len, correct_outlen)
|
||||||
|
|
||||||
|
# decoder attentions
|
||||||
|
decoder_attentions = outputs.decoder_attentions
|
||||||
|
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||||
|
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(decoder_attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||||
|
)
|
||||||
|
|
||||||
|
# cross attentions
|
||||||
|
cross_attentions = outputs.cross_attentions
|
||||||
|
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||||
|
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(cross_attentions[0].shape[-3:]),
|
||||||
|
[
|
||||||
|
self.model_tester.num_attention_heads,
|
||||||
|
decoder_seq_length,
|
||||||
|
encoder_key_length,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check attention is always last and order is fine
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||||
|
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||||
|
elif self.is_encoder_decoder:
|
||||||
|
# decoder_hidden_states, encoder_last_hidden_state, encoder_hidden_states
|
||||||
|
added_hidden_states = 3
|
||||||
|
else:
|
||||||
|
added_hidden_states = 1
|
||||||
|
|
||||||
|
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||||
|
|
||||||
|
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||||
|
|
||||||
|
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(self_attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
# removed retain_grad and grad on decoder_hidden_states, as queries don't require grad
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# no need to test all models as different heads yield the same functionality
|
||||||
|
model_class = self.all_model_classes[0]
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
|
||||||
|
|
||||||
|
# logits
|
||||||
|
output = outputs[0]
|
||||||
|
|
||||||
|
encoder_hidden_states = outputs.encoder_hidden_states[0]
|
||||||
|
encoder_hidden_states.retain_grad()
|
||||||
|
|
||||||
|
encoder_attentions = outputs.encoder_attentions[0]
|
||||||
|
encoder_attentions.retain_grad()
|
||||||
|
|
||||||
|
decoder_attentions = outputs.decoder_attentions[0]
|
||||||
|
decoder_attentions.retain_grad()
|
||||||
|
|
||||||
|
cross_attentions = outputs.cross_attentions[0]
|
||||||
|
cross_attentions.retain_grad()
|
||||||
|
|
||||||
|
output.flatten()[0].backward(retain_graph=True)
|
||||||
|
|
||||||
|
self.assertIsNotNone(encoder_hidden_states.grad)
|
||||||
|
self.assertIsNotNone(encoder_attentions.grad)
|
||||||
|
self.assertIsNotNone(decoder_attentions.grad)
|
||||||
|
self.assertIsNotNone(cross_attentions.grad)
|
||||||
|
|
||||||
|
def test_forward_auxiliary_loss(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.auxiliary_loss = True
|
||||||
|
|
||||||
|
# only test for object detection and segmentation model
|
||||||
|
for model_class in self.all_model_classes[1:]:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
self.assertIsNotNone(outputs.auxiliary_outputs)
|
||||||
|
self.assertEqual(len(outputs.auxiliary_outputs), self.model_tester.num_hidden_layers - 1)
|
||||||
|
|
||||||
|
def test_training(self):
|
||||||
|
if not self.model_tester.is_training:
|
||||||
|
self.skipTest(reason="ModelTester is not configured to run training tests")
|
||||||
|
|
||||||
|
# We only have loss with ObjectDetection
|
||||||
|
model_class = self.all_model_classes[-1]
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
expected_arg_names = ["pixel_values", "pixel_mask"]
|
||||||
|
expected_arg_names.extend(
|
||||||
|
["head_mask", "decoder_head_mask", "encoder_outputs"]
|
||||||
|
if "head_mask" and "decoder_head_mask" in arg_names
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
|
else:
|
||||||
|
expected_arg_names = ["pixel_values", "pixel_mask"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
def test_different_timm_backbone(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# let's pick a random timm backbone
|
||||||
|
config.backbone = "tf_mobilenetv3_small_075"
|
||||||
|
config.backbone_config = None
|
||||||
|
config.use_timm_backbone = True
|
||||||
|
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
if model_class.__name__ == "DabDetrForObjectDetection":
|
||||||
|
expected_shape = (
|
||||||
|
self.model_tester.batch_size,
|
||||||
|
self.model_tester.num_queries,
|
||||||
|
self.model_tester.num_labels,
|
||||||
|
)
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
|
# Confirm out_indices was propogated to backbone
|
||||||
|
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||||
|
else:
|
||||||
|
# Confirm out_indices was propogated to backbone
|
||||||
|
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
|
||||||
|
|
||||||
|
self.assertTrue(outputs)
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
configs_no_init.init_xavier_std = 1e9
|
||||||
|
# Copied from RT-DETR
|
||||||
|
configs_no_init.initializer_bias_prior_prob = 0.2
|
||||||
|
bias_value = -1.3863 # log_e ((1 - 0.2) / 0.2)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if "bbox_attention" in name and "bias" not in name:
|
||||||
|
self.assertLess(
|
||||||
|
100000,
|
||||||
|
abs(param.data.max().item()),
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
# Modifed from RT-DETR
|
||||||
|
elif "class_embed" in name and "bias" in name:
|
||||||
|
bias_tensor = torch.full_like(param.data, bias_value)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
param.data,
|
||||||
|
bias_tensor,
|
||||||
|
atol=1e-4,
|
||||||
|
rtol=1e-4,
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
elif "activation_fn" in name and config.activation_function == "prelu":
|
||||||
|
self.assertTrue(
|
||||||
|
param.data.mean() == 0.25,
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
elif "backbone.conv_encoder.model" in name:
|
||||||
|
continue
|
||||||
|
elif "self_attn.in_proj_weight" in name:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e2).round() / 1e2).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TOLERANCE = 1e-4
|
||||||
|
CHECKPOINT = "IDEA-Research/dab-detr-resnet-50"
|
||||||
|
|
||||||
|
|
||||||
|
# We will verify our results on an image of cute cats
|
||||||
|
def prepare_img():
|
||||||
|
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
@require_timm
|
||||||
|
@require_vision
|
||||||
|
@slow
|
||||||
|
class DabDetrModelIntegrationTests(unittest.TestCase):
|
||||||
|
@cached_property
|
||||||
|
def default_image_processor(self):
|
||||||
|
return ConditionalDetrImageProcessor.from_pretrained(CHECKPOINT) if is_vision_available() else None
|
||||||
|
|
||||||
|
def test_inference_no_head(self):
|
||||||
|
model = DabDetrModel.from_pretrained(CHECKPOINT).to(torch_device)
|
||||||
|
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img()
|
||||||
|
encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(pixel_values=encoding.pixel_values)
|
||||||
|
|
||||||
|
expected_shape = torch.Size((1, 300, 256))
|
||||||
|
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
|
||||||
|
expected_slice = torch.tensor(
|
||||||
|
[[-0.4879, -0.2594, 0.4524], [-0.4997, -0.4258, 0.4329], [-0.8220, -0.4996, 0.0577]]
|
||||||
|
).to(torch_device)
|
||||||
|
torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=2e-4, rtol=2e-4)
|
||||||
|
|
||||||
|
def test_inference_object_detection_head(self):
|
||||||
|
model = DabDetrForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device)
|
||||||
|
|
||||||
|
image_processor = self.default_image_processor
|
||||||
|
image = prepare_img()
|
||||||
|
encoding = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||||
|
pixel_values = encoding["pixel_values"].to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(pixel_values)
|
||||||
|
|
||||||
|
# verify logits + box predictions
|
||||||
|
expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels))
|
||||||
|
self.assertEqual(outputs.logits.shape, expected_shape_logits)
|
||||||
|
expected_slice_logits = torch.tensor(
|
||||||
|
[[-10.1765, -5.5243, -8.9324], [-9.8138, -5.6721, -7.5161], [-10.3054, -5.6081, -8.5931]]
|
||||||
|
).to(torch_device)
|
||||||
|
torch.testing.assert_close(outputs.logits[0, :3, :3], expected_slice_logits, atol=3e-4, rtol=3e-4)
|
||||||
|
|
||||||
|
expected_shape_boxes = torch.Size((1, model.config.num_queries, 4))
|
||||||
|
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
|
||||||
|
expected_slice_boxes = torch.tensor(
|
||||||
|
[[0.3708, 0.3000, 0.2753], [0.5211, 0.6125, 0.9495], [0.2897, 0.6730, 0.5459]]
|
||||||
|
).to(torch_device)
|
||||||
|
torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
# verify postprocessing
|
||||||
|
results = image_processor.post_process_object_detection(
|
||||||
|
outputs, threshold=0.3, target_sizes=[image.size[::-1]]
|
||||||
|
)[0]
|
||||||
|
expected_scores = torch.tensor([0.8732, 0.8563, 0.8554, 0.6079, 0.5896]).to(torch_device)
|
||||||
|
expected_labels = [17, 75, 17, 75, 63]
|
||||||
|
expected_boxes = torch.tensor([14.6970, 49.3892, 320.5165, 469.2765]).to(torch_device)
|
||||||
|
|
||||||
|
self.assertEqual(len(results["scores"]), 5)
|
||||||
|
torch.testing.assert_close(results["scores"], expected_scores, atol=1e-4, rtol=1e-4)
|
||||||
|
self.assertSequenceEqual(results["labels"].tolist(), expected_labels)
|
||||||
|
torch.testing.assert_close(results["boxes"][0, :], expected_boxes, atol=1e-4, rtol=1e-4)
|
@ -161,6 +161,16 @@ SPECIAL_CASES_TO_ALLOW = {
|
|||||||
"giou_loss_coefficient",
|
"giou_loss_coefficient",
|
||||||
"mask_loss_coefficient",
|
"mask_loss_coefficient",
|
||||||
],
|
],
|
||||||
|
"DabDetrConfig": [
|
||||||
|
"dilation",
|
||||||
|
"bbox_cost",
|
||||||
|
"bbox_loss_coefficient",
|
||||||
|
"class_cost",
|
||||||
|
"cls_loss_coefficient",
|
||||||
|
"focal_alpha",
|
||||||
|
"giou_cost",
|
||||||
|
"giou_loss_coefficient",
|
||||||
|
],
|
||||||
"DetrConfig": [
|
"DetrConfig": [
|
||||||
"bbox_cost",
|
"bbox_cost",
|
||||||
"bbox_loss_coefficient",
|
"bbox_loss_coefficient",
|
||||||
|
Loading…
Reference in New Issue
Block a user