<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DAB-DETR

## Overview

The DAB-DETR model was proposed in [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329) by Shilong Liu, Feng Li, Hao Zhang, Xiao Yang, Xianbiao Qi, Hang Su, Jun Zhu, and Lei Zhang.
DAB-DETR is an enhanced variant of Conditional DETR. It uses dynamically updated anchor boxes as decoder queries: each box provides both a reference point (x, y) and a reference anchor size (w, h), which improves the cross-attention computation. With a single ResNet-50 backbone trained for 50 epochs, this approach reaches 45.7% AP on MS-COCO (the paper reports this figure for the ResNet50-DC5 variant).

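To make the idea of a box-shaped query more concrete, here is a minimal sketch of how an anchor (x, y, w, h) could be turned into a positional vector by concatenating one sinusoidal embedding per coordinate. This is an illustration only, not the library's implementation: the function names, the embedding size `dim=8`, and the temperature value are assumptions chosen for readability, and the real model passes the result through learned layers afterwards.

```python
import math


def sinusoidal_embedding(value, dim=8, temperature=10000.0):
    """Encode one scalar in [0, 1] as `dim` interleaved sin/cos features.

    `dim` and `temperature` are illustrative values, not the model's settings.
    """
    features = []
    for i in range(dim // 2):
        frequency_scale = temperature ** (2 * i / dim)
        angle = 2 * math.pi * value / frequency_scale
        features.append(math.sin(angle))
        features.append(math.cos(angle))
    return features


def anchor_positional_query(anchor, dim=8):
    """Concatenate the embeddings of x, y, w and h into one positional vector."""
    query = []
    for coordinate in anchor:
        query.extend(sinusoidal_embedding(coordinate, dim))
    return query


# Hypothetical anchor: centered box covering 20% of the width and 40% of the height
anchor = (0.5, 0.5, 0.2, 0.4)
query = anchor_positional_query(anchor)
print(len(query))  # 4 coordinates * 8 features = 32
```

Because the width and height are part of the query itself, the decoder can take the box size into account when attending to image features, which is what the paper refers to as modulating the positional attention map.
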
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dab_detr_convergence_plot.png"
alt="drawing" width="600"/>

The abstract from the paper is the following:

*We present in this paper a novel query formulation using dynamic anchor boxes
for DETR (DEtection TRansformer) and offer a deeper understanding of the role
of queries in DETR. This new formulation directly uses box coordinates as queries
in Transformer decoders and dynamically updates them layer-by-layer. Using box
coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR,
but also allows us to modulate the positional attention map using the box width
and height information. Such a design makes it clear that queries in DETR can be
implemented as performing soft ROI pooling layer-by-layer in a cascade manner.
As a result, it leads to the best performance on MS-COCO benchmark among
the DETR-like detection models under the same setting, e.g., AP 45.7% using
ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive
experiments to confirm our analysis and verify the effectiveness of our methods.*

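The layer-by-layer update described in the abstract can be sketched as follows. This is a simplified illustration under stated assumptions, not the code used in `DabDetrModel`: anchors are assumed to be normalized to (0, 1), each decoder layer is assumed to predict an offset (Δx, Δy, Δw, Δh) in logit space, and the offset values below are made up for the example (in the real model they come from a learned prediction head).

```python
import math


def inverse_sigmoid(p, eps=1e-5):
    """Map a value in (0, 1) back to logit space, clamping to avoid log(0)."""
    p = min(max(p, eps), 1 - eps)
    return math.log(p / (1 - p))


def sigmoid(z):
    return 1 / (1 + math.exp(-z))


def refine_anchor(anchor, delta):
    """Apply one layer's offsets in logit space, then squash back to (0, 1)."""
    return tuple(sigmoid(inverse_sigmoid(a) + d) for a, d in zip(anchor, delta))


# Hypothetical initial anchor (x, y, w, h), normalized to the image size
anchor = (0.5, 0.5, 0.2, 0.2)

# Hypothetical per-layer offsets; a real decoder would predict these
layer_deltas = [
    (0.4, -0.2, 0.3, 0.1),
    (0.1, 0.0, 0.1, 0.05),
    (0.0, 0.0, 0.0, 0.0),
]

for layer_index, delta in enumerate(layer_deltas):
    anchor = refine_anchor(anchor, delta)
    print(layer_index, [round(v, 3) for v in anchor])
```

Working in logit space keeps every refined coordinate inside (0, 1) no matter how large the predicted offset is, so each layer can start from the previous layer's box and nudge it toward the object.
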
This model was contributed by [davidhajdu](https://huggingface.co/davidhajdu).
The original code can be found [here](https://github.com/IDEA-Research/DAB-DETR).

## How to Get Started with the Model

Use the code below to get started with the model.

```python
import torch
import requests

from PIL import Image
from transformers import AutoModelForObjectDetection, AutoImageProcessor

# Download a sample image from the COCO validation set
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")

inputs = image_processor(images=image, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# Rescale boxes to the original (height, width) and keep detections scoring at least 0.3
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```
This should output:
```
cat: 0.87 [14.7, 49.39, 320.52, 469.28]
remote: 0.86 [41.08, 72.37, 173.39, 117.2]
cat: 0.86 [344.45, 19.43, 639.85, 367.86]
remote: 0.61 [334.27, 75.93, 367.92, 188.81]
couch: 0.59 [-0.04, 1.34, 639.9, 477.09]
```

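The boxes printed above are in absolute pixel coordinates as (x_min, y_min, x_max, y_max), and they can fall slightly outside the image (note the `-0.04` for the couch). If downstream code needs boxes clipped to the image or in COCO-style (x, y, width, height) format, a small helper like the sketch below could be used. The helper names are hypothetical, the detections are copied from the output above, and the 640x480 image size is an assumption about the sample image.

```python
def clip_box(box, image_width, image_height):
    """Clamp an (x_min, y_min, x_max, y_max) box to the image bounds."""
    x_min, y_min, x_max, y_max = box
    return [
        min(max(x_min, 0.0), image_width),
        min(max(y_min, 0.0), image_height),
        min(max(x_max, 0.0), image_width),
        min(max(y_max, 0.0), image_height),
    ]


def to_xywh(box):
    """Convert (x_min, y_min, x_max, y_max) to (x, y, width, height)."""
    x_min, y_min, x_max, y_max = box
    return [x_min, y_min, x_max - x_min, y_max - y_min]


# Two detections copied from the printed output above
detections = [
    ("cat", 0.87, [14.7, 49.39, 320.52, 469.28]),
    ("couch", 0.59, [-0.04, 1.34, 639.9, 477.09]),
]
image_width, image_height = 640, 480  # assumed size of the sample image

for label, score, box in detections:
    clipped = clip_box(box, image_width, image_height)
    x, y, width, height = to_xywh(clipped)
    print(f"{label}: {score:.2f} x={x:.2f} y={y:.2f} w={width:.2f} h={height:.2f}")
```
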
There are three other ways to instantiate a DAB-DETR model (depending on what you prefer):

Option 1: Instantiate DAB-DETR with pre-trained weights for the entire model
```py
>>> from transformers import DabDetrForObjectDetection

>>> model = DabDetrForObjectDetection.from_pretrained("IDEA-Research/dab-detr-resnet-50")
```

Option 2: Instantiate DAB-DETR with randomly initialized weights for the Transformer, but pre-trained weights for the backbone
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection

>>> config = DabDetrConfig()
>>> model = DabDetrForObjectDetection(config)
```

Option 3: Instantiate DAB-DETR with randomly initialized weights for both the backbone and the Transformer
```py
>>> from transformers import DabDetrConfig, DabDetrForObjectDetection

>>> config = DabDetrConfig(use_pretrained_backbone=False)
>>> model = DabDetrForObjectDetection(config)
```

## DabDetrConfig

[[autodoc]] DabDetrConfig

## DabDetrModel

[[autodoc]] DabDetrModel
    - forward

## DabDetrForObjectDetection

[[autodoc]] DabDetrForObjectDetection
    - forward