New model support RTDETR (#29077)

* fill out docs string in configuration
75dcd3a0e8 (r1506391856)

* reduce the input image size for the tests

* remove the inappropriate tests

* only 5 failures remain

* make style

* fill up missed architecture for object detection in docs

* fix auto modeling

* simple fix in missing import

* major change including backbone refactor and ObjectDetectionOutput refactor

* minor fix, only 4 failures left

* intermediate fix

* revert __init__.py

* revert __init__.py

* make style

* fixes in pr_docs

* intermediate fix

* make style

* two fixes

* pass doctest

* only one fix left

* intermediate commit

* all fixed

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/rt_detr/test_modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* move function class above the model definition in dice_loss

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* simple fix

* layernorm add config.layer_norm_eps

* fix inputs_docstring

* make style

* simple fix

* add custom coco loading test in image_processor

* fix error in BaseModelOutput
https://github.com/huggingface/transformers/pull/29077#discussion_r1516657790

* simple typo

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* intermediate fix

* fix with load_backbone format

* remove unused configuration

* 3 test fixes left

* make style

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: Sounak Dey <dey.sounak@gmail.com>

* change last_hidden_state to first index

* all pass fix
TO DO: minor update in comments

* make fix-copies

* remove deepcopy

* pr_document fix

* revert deepcopy due to unexpected behavior in the decoder layer

* add atol in final

* add no_split_module

* _no_split_modules = None

* device transfer for model parallelism

* minor fix

* make fix-copies

* fix typo

* add test_image_processor with post_processing

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add config in RTDETRPredictionHead

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set lru_cache with max_size 32

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add lru_cache import and configuration change

* change the order of definition

* make fix-copies

* add docs and change config error

* revert strange make-fix

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* test pass

* fix get_clones related and remove deepcopy

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* nit for paper section

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* rename denoising related parameters

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* check the image transformation logic

* make style

* make style

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* pe_encoding -> positional_encoding_temperature

* remove TODO

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* remove eval_idx since transformers DETR returns all decoder outputs

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* change variable name

* make style and docs import update

* Revert "Update src/transformers/models/rt_detr/image_processing_rt_detr.py"

This reverts commit 74aa3e1de0.

* fix typo

* add postprocessing in docs

* move import scipy to top

* change variable name

* make fix-copies

* remove eval_idx in test

* move to after first sentence

* update image_processor since the box loss requires normalized boxes

* change appropriate name to auxiliary_outputs

* Update src/transformers/models/rt_detr/__init__.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/__init__.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* make style

* remove panoptic related comments

* make style

* revert valid_processor_keys

* fix aux related test

* make style

* change origination from config to backbone API

* enable the dn_loss

* fix test and conversion

* rework weight initialization

* change initializer_range

* make fixup

* fix the loss issue in the auxiliary output and denoising part

* change weight loss to original RTDETR

* fix in initialization

* sync shape format of dn and aux

* make style

* stable fine-tuning and compatible conversion for resnet101

* make style

* skip input_embed

* change encoder related variable

* enable converting rtdetr_r101

* add r101 related conversion code

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change name _shape to _reshape

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/__init__.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style

* make fix-copies

* remove deprecated import

* more fix

* remove last_hidden_state for task-specific model

* Revert "remove last_hidden_state for task-specific model"

This reverts commit ccb7a34051.

* minor change in convert

* remove print

* make style and fix-copies

* add custom rtdetr backbone for r18, r34

* remove print

* change copied

* add pad_size

* make style

* change layer_type to optional to pass the CI

* make style

* add test in modeling_resnet_rt_detr

* make fix-copies

* skip tmp file test

* fix comment

* add docs

* change to modeling_resnet file format

* enable resnet50 and above

* Update src/transformers/models/rt_detr/modeling_rt_detr.py

Co-authored-by: Jason Wu <jasonkit@users.noreply.github.com>

* enable all the rtdetr model :)

* finish except CI

* add RTDetrResNetBackbone

* make fix-copies

* fix
TO DO: CI enable

* make style

* rename test

* add docs

* add special fix

* revert resnet

* Update src/transformers/models/rt_detr/modeling_rt_detr_resnet.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* add more comment

* remove swin comment

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* rename convert and add verify backbone

* Update docs/source/en/_toctree.yml

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* make style

* requests for docs

* more general test docs

* general script docs

* make fix-copies

* final commit

* Revert "Update src/transformers/models/rt_detr/configuration_rt_detr.py"

This reverts commit d136225cd3.

* skip test_model_get_set_embeddings

* remove target

* add changes

* make fix-copies

* remove decoder_attention_mask

* add load_backbone function for auto_backbone

* remove comment

* fix repo name

* Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final commit

* remove unused downsample_in_bottleneck

* new test for autobackbone

* change to appropriate indices

* test fix

* fix dict in test_image_processor

* fix test

* [run-slow] rt_detr, rt_detr_resnet

* change the slow test

* [run-slow] rt_detr

* [run-slow] rt_detr, rt_detr_resnet

* move to the same cuda device in CSPRepLayer

* [run-slow] rt_detr, rt_detr_resnet

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sounak Dey <dey.sounak@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Jason Wu <jasonkit@users.noreply.github.com>
Co-authored-by: ChoiSangBum <choisangbum@ChoiSangBumui-MacBookPro.local>
Sangbum Daniel Choi authored on 2024-06-22 01:50:08 +09:00; committed by GitHub
parent 8b7cd40273
commit 74a207404e
24 changed files with 6898 additions and 15 deletions

@@ -627,6 +627,8 @@
title: RegNet
- local: model_doc/resnet
title: ResNet
- local: model_doc/rt_detr
title: RT-DETR
- local: model_doc/segformer
title: SegFormer
- local: model_doc/seggpt

@@ -262,6 +262,8 @@ Flax), PyTorch, and/or TensorFlow.
| [RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm) | ✅ | ✅ | ✅ |
| [RoCBert](model_doc/roc_bert) | ✅ | ❌ | ❌ |
| [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ |
| [RT-DETR](model_doc/rt_detr) | ✅ | ❌ | ❌ |
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ |
| [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ |
| [SAM](model_doc/sam) | ✅ | ✅ | ❌ |
| [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ |

@@ -0,0 +1,85 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# RT-DETR
## Overview
The RT-DETR model was proposed in [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069) by Wenyu Lv, Yian Zhao, Shangliang Xu, Jinman Wei, Guanzhong Wang, Cheng Cui, Yuning Du, Qingqing Dang, Yi Liu.
RT-DETR is an object detection model that stands for "Real-Time DEtection Transformer." This model is designed to perform object detection tasks with a focus on achieving real-time performance while maintaining high accuracy. Leveraging the transformer architecture, which has gained significant popularity in various fields of deep learning, RT-DETR processes images to identify and locate multiple objects within them.
The abstract from the paper is the following:
*Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS.*
The model version was contributed by [rafaelpadilla](https://huggingface.co/rafaelpadilla) and [sangbumchoi](https://github.com/SangbumChoi). The original code can be found [here](https://github.com/lyuwenyu/RT-DETR/).
## Usage tips
Initially, an image is processed using a pre-trained convolutional neural network, specifically a Resnet-D variant as referenced in the original code. This network extracts features from the final three layers of the architecture. Following this, a hybrid encoder is employed to convert the multi-scale features into a sequential array of image features. Then, a decoder, equipped with auxiliary prediction heads is used to refine the object queries. This process facilitates the direct generation of bounding boxes, eliminating the need for any additional post-processing to acquire the logits and coordinates for the bounding boxes.
```py
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from PIL import Image
import json
import torch
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)
```
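A short continuation of the example above, assuming the standard `scores`, `labels` and `boxes` keys returned by `post_process_object_detection`:
```py
# Inspect the detections kept after thresholding (one result dict per input image).
for result in results:
    for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
        box = [round(x, 2) for x in box.tolist()]
        print(f"{model.config.id2label[label.item()]}: {score.item():.2f} {box}")
```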
## RTDetrConfig
[[autodoc]] RTDetrConfig
## RTDetrResNetConfig
[[autodoc]] RTDetrResNetConfig
## RTDetrImageProcessor
[[autodoc]] RTDetrImageProcessor
- preprocess
- post_process_object_detection
## RTDetrModel
[[autodoc]] RTDetrModel
- forward
## RTDetrForObjectDetection
[[autodoc]] RTDetrForObjectDetection
- forward
## RTDetrResNetBackbone
[[autodoc]] RTDetrResNetBackbone
- forward

@@ -654,6 +654,7 @@ _import_structure = {
"RoFormerConfig",
"RoFormerTokenizer",
],
"models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
"models.rwkv": ["RwkvConfig"],
"models.sam": [
"SamConfig",
@@ -1153,6 +1154,7 @@ else:
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
_import_structure["models.sam"].extend(["SamImageProcessor"])
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
@@ -3004,6 +3006,15 @@ else:
"load_tf_weights_in_roformer",
]
)
_import_structure["models.rt_detr"].extend(
[
"RTDetrForObjectDetection",
"RTDetrModel",
"RTDetrPreTrainedModel",
"RTDetrResNetBackbone",
"RTDetrResNetPreTrainedModel",
]
)
_import_structure["models.rwkv"].extend(
[
"RwkvForCausalLM",
@@ -5270,6 +5281,10 @@ if TYPE_CHECKING:
RoFormerConfig,
RoFormerTokenizer,
)
from .models.rt_detr import (
RTDetrConfig,
RTDetrResNetConfig,
)
from .models.rwkv import RwkvConfig
from .models.sam import (
SamConfig,
@@ -5792,6 +5807,7 @@ if TYPE_CHECKING:
PoolFormerImageProcessor,
)
from .models.pvt import PvtImageProcessor
from .models.rt_detr import RTDetrImageProcessor
from .models.sam import SamImageProcessor
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
from .models.seggpt import SegGptImageProcessor
@@ -7295,6 +7311,13 @@ if TYPE_CHECKING:
RoFormerPreTrainedModel,
load_tf_weights_in_roformer,
)
from .models.rt_detr import (
RTDetrForObjectDetection,
RTDetrModel,
RTDetrPreTrainedModel,
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)
from .models.rwkv import (
RwkvForCausalLM,
RwkvModel,

@@ -193,6 +193,7 @@ from . import (
roberta_prelayernorm,
roc_bert,
roformer,
rt_detr,
rwkv,
sam,
seamless_m4t,

@@ -214,6 +214,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
("roc_bert", "RoCBertConfig"),
("roformer", "RoFormerConfig"),
("rt_detr", "RTDetrConfig"),
("rt_detr_resnet", "RTDetrResNetConfig"),
("rwkv", "RwkvConfig"),
("sam", "SamConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
@@ -499,6 +501,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
("roc_bert", "RoCBert"),
("roformer", "RoFormer"),
("rt_detr", "RT-DETR"),
("rt_detr_resnet", "RT-DETR-ResNet"),
("rwkv", "RWKV"),
("sam", "SAM"),
("seamless_m4t", "SeamlessM4T"),
@@ -623,6 +627,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("clip_vision_model", "clip"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
]
)

@@ -114,6 +114,7 @@ else:
("pvt_v2", ("PvtImageProcessor",)),
("regnet", ("ConvNextImageProcessor",)),
("resnet", ("ConvNextImageProcessor",)),
("rt_detr", "RTDetrImageProcessor"),
("sam", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),

@@ -202,6 +202,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("rt_detr", "RTDetrModel"),
("rwkv", "RwkvModel"),
("sam", "SamModel"),
("seamless_m4t", "SeamlessM4TModel"),
@@ -765,6 +766,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
("deformable_detr", "DeformableDetrForObjectDetection"),
("deta", "DetaForObjectDetection"),
("detr", "DetrForObjectDetection"),
("rt_detr", "RTDetrForObjectDetection"),
("table-transformer", "TableTransformerForObjectDetection"),
("yolos", "YolosForObjectDetection"),
]
@@ -1252,6 +1254,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("nat", "NatBackbone"),
("pvt_v2", "PvtV2Backbone"),
("resnet", "ResNetBackbone"),
("rt_detr_resnet", "RTDetrResNetBackbone"),
("swin", "SwinBackbone"),
("swinv2", "Swinv2Backbone"),
("timm_backbone", "TimmBackbone"),

@@ -29,22 +29,24 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
requires_backends,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_ninja_available, logging
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_ninja_available,
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig

@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_rt_detr": ["RTDetrConfig"], "configuration_rt_detr_resnet": ["RTDetrResNetConfig"]}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_rt_detr"] = ["RTDetrImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_rt_detr"] = [
"RTDetrForObjectDetection",
"RTDetrModel",
"RTDetrPreTrainedModel",
]
_import_structure["modeling_rt_detr_resnet"] = [
"RTDetrResNetBackbone",
"RTDetrResNetPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_rt_detr import RTDetrConfig
from .configuration_rt_detr_resnet import RTDetrResNetConfig
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_rt_detr import RTDetrImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_rt_detr import (
RTDetrForObjectDetection,
RTDetrModel,
RTDetrPreTrainedModel,
)
from .modeling_rt_detr_resnet import (
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

@@ -0,0 +1,352 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RT-DETR model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING
from .configuration_rt_detr_resnet import RTDetrResNetConfig
logger = logging.get_logger(__name__)
class RTDetrConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the RT-DETR
[checkpoing/todo](https://huggingface.co/checkpoing/todo) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
initializer_range (`float`, *optional*, defaults to 0.01):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the batch normalization layers.
backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
encoder_hidden_dim (`int`, *optional*, defaults to 256):
Dimension of the layers in hybrid encoder.
encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
Multi level features input for encoder.
feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
Strides used in each feature map.
encoder_layers (`int`, *optional*, defaults to 1):
Total of layers to be used by the encoder.
encoder_ffn_dim (`int`, *optional*, defaults to 1024):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
encoder_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
dropout (`float`, *optional*, defaults to 0.0):
The ratio for all dropout layers.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
Indexes of the projected layers to be used in the encoder.
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
The temperature parameter used to create the positional encodings.
encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
activation_function (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
eval_size (`Tuple[int, int]`, *optional*):
Height and width used to compute the effective height and width of the position embeddings after taking
into account the stride.
normalize_before (`bool`, *optional*, defaults to `False`):
Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
feed-forward modules.
hidden_expansion (`float`, *optional*, defaults to 1.0):
Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
d_model (`int`, *optional*, defaults to 256):
Dimension of the layers excluding the hybrid encoder.
num_queries (`int`, *optional*, defaults to 300):
Number of object queries.
decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
Multi level features dimension for decoder
decoder_ffn_dim (`int`, *optional*, defaults to 1024):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
num_feature_levels (`int`, *optional*, defaults to 3):
The number of input feature levels.
decoder_n_points (`int`, *optional*, defaults to 4):
The number of sampled keys in each feature level for each attention head in the decoder.
decoder_layers (`int`, *optional*, defaults to 6):
Number of decoder layers.
decoder_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_denoising (`int`, *optional*, defaults to 100):
The total number of denoising tasks or queries to be used for contrastive denoising.
label_noise_ratio (`float`, *optional*, defaults to 0.5):
The fraction of denoising labels to which random noise should be added.
box_noise_scale (`float`, *optional*, defaults to 1.0):
Scale or magnitude of noise to be added to the bounding boxes.
learn_initial_query (`bool`, *optional*, defaults to `False`):
Indicates whether the initial query embeddings for the decoder should be learned during training
anchor_image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`):
Height and width of the input image used during evaluation to generate the bounding box anchors.
disable_custom_kernels (`bool`, *optional*, defaults to `True`):
Whether to disable custom kernels.
with_box_refine (`bool`, *optional*, defaults to `True`):
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
based on the predictions from the previous layer.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Whether the architecture has an encoder decoder structure.
matcher_alpha (`float`, *optional*, defaults to 0.25):
Parameter alpha used by the Hungarian Matcher.
matcher_gamma (`float`, *optional*, defaults to 2.0):
Parameter gamma used by the Hungarian Matcher.
matcher_class_cost (`float`, *optional*, defaults to 2.0):
The relative weight of the class loss used by the Hungarian Matcher.
matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
The relative weight of the bounding box loss used by the Hungarian Matcher.
matcher_giou_cost (`float`, *optional*, defaults to 2.0):
The relative weight of the giou loss used by the Hungarian Matcher.
use_focal_loss (`bool`, *optional*, defaults to `True`):
Parameter informing if focal loss should be used.
auxiliary_loss (`bool`, *optional*, defaults to `True`):
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
focal_loss_alpha (`float`, *optional*, defaults to 0.75):
Parameter alpha used to compute the focal loss.
focal_loss_gamma (`float`, *optional*, defaults to 2.0):
Parameter gamma used to compute the focal loss.
weight_loss_vfl (`float`, *optional*, defaults to 1.0):
Relative weight of the varifocal loss in the object detection loss.
weight_loss_bbox (`float`, *optional*, defaults to 5.0):
Relative weight of the L1 bounding box loss in the object detection loss.
weight_loss_giou (`float`, *optional*, defaults to 2.0):
Relative weight of the generalized IoU loss in the object detection loss.
eos_coefficient (`float`, *optional*, defaults to 0.0001):
Relative classification weight of the 'no-object' class in the object detection loss.
Examples:
```python
>>> from transformers import RTDetrConfig, RTDetrModel
>>> # Initializing a RT-DETR configuration
>>> configuration = RTDetrConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = RTDetrModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "rt_detr"
layer_types = ["basic", "bottleneck"]
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "encoder_attention_heads",
}
def __init__(
self,
initializer_range=0.01,
layer_norm_eps=1e-5,
batch_norm_eps=1e-5,
# backbone
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
# encoder HybridEncoder
encoder_hidden_dim=256,
encoder_in_channels=[512, 1024, 2048],
feat_strides=[8, 16, 32],
encoder_layers=1,
encoder_ffn_dim=1024,
encoder_attention_heads=8,
dropout=0.0,
activation_dropout=0.0,
encode_proj_layers=[2],
positional_encoding_temperature=10000,
encoder_activation_function="gelu",
activation_function="silu",
eval_size=None,
normalize_before=False,
hidden_expansion=1.0,
# decoder RTDetrTransformer
d_model=256,
num_queries=300,
decoder_in_channels=[256, 256, 256],
decoder_ffn_dim=1024,
num_feature_levels=3,
decoder_n_points=4,
decoder_layers=6,
decoder_attention_heads=8,
decoder_activation_function="relu",
attention_dropout=0.0,
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0,
learn_initial_query=False,
anchor_image_size=[640, 640],
disable_custom_kernels=True,
with_box_refine=True,
is_encoder_decoder=True,
# Loss
matcher_alpha=0.25,
matcher_gamma=2.0,
matcher_class_cost=2.0,
matcher_bbox_cost=5.0,
matcher_giou_cost=2.0,
use_focal_loss=True,
auxiliary_loss=True,
focal_loss_alpha=0.75,
focal_loss_gamma=2.0,
weight_loss_vfl=1.0,
weight_loss_bbox=5.0,
weight_loss_giou=2.0,
eos_coefficient=1e-4,
**kwargs,
):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.batch_norm_eps = batch_norm_eps
# backbone
if backbone_config is None and backbone is None:
logger.info(
"`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
)
backbone_config = RTDetrResNetConfig(
num_channels=3,
embedding_size=64,
hidden_sizes=[256, 512, 1024, 2048],
depths=[3, 4, 6, 3],
layer_type="bottleneck",
hidden_act="relu",
downsample_in_first_stage=False,
downsample_in_bottleneck=False,
out_features=None,
out_indices=[2, 3, 4],
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.pop("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
# encoder
self.encoder_hidden_dim = encoder_hidden_dim
self.encoder_in_channels = encoder_in_channels
self.feat_strides = feat_strides
self.encoder_attention_heads = encoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
self.dropout = dropout
self.activation_dropout = activation_dropout
self.encode_proj_layers = encode_proj_layers
self.encoder_layers = encoder_layers
self.positional_encoding_temperature = positional_encoding_temperature
self.eval_size = eval_size
self.normalize_before = normalize_before
self.encoder_activation_function = encoder_activation_function
self.activation_function = activation_function
self.hidden_expansion = hidden_expansion
# decoder
self.d_model = d_model
self.num_queries = num_queries
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_in_channels = decoder_in_channels
self.num_feature_levels = num_feature_levels
self.decoder_n_points = decoder_n_points
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.decoder_activation_function = decoder_activation_function
self.attention_dropout = attention_dropout
self.num_denoising = num_denoising
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
self.learn_initial_query = learn_initial_query
self.anchor_image_size = anchor_image_size
self.auxiliary_loss = auxiliary_loss
self.disable_custom_kernels = disable_custom_kernels
self.with_box_refine = with_box_refine
# Loss
self.matcher_alpha = matcher_alpha
self.matcher_gamma = matcher_gamma
self.matcher_class_cost = matcher_class_cost
self.matcher_bbox_cost = matcher_bbox_cost
self.matcher_giou_cost = matcher_giou_cost
self.use_focal_loss = use_focal_loss
self.focal_loss_alpha = focal_loss_alpha
self.focal_loss_gamma = focal_loss_gamma
self.weight_loss_vfl = weight_loss_vfl
self.weight_loss_bbox = weight_loss_bbox
self.weight_loss_giou = weight_loss_giou
self.eos_coefficient = eos_coefficient
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self) -> int:
return self.d_model
@classmethod
def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
"""Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
configuration.
Args:
backbone_config ([`PretrainedConfig`]):
The backbone configuration.
Returns:
[`RTDetrConfig`]: An instance of a configuration object
"""
return cls(
backbone_config=backbone_config,
**kwargs,
)
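
For illustration, the `from_backbone_configs` helper above can pair a custom RT-DETR ResNet backbone with the detector configuration; a minimal sketch, where the ResNet-18-style values mirror those used for `rtdetr_r18vd` in the conversion script later in this diff:
```python
# Sketch: build an RTDetrConfig around a smaller, ResNet-18-style RTDetrResNetConfig backbone.
from transformers import RTDetrConfig, RTDetrResNetConfig

backbone_config = RTDetrResNetConfig(
    hidden_sizes=[64, 128, 256, 512],
    depths=[2, 2, 2, 2],
    layer_type="basic",
    out_indices=[2, 3, 4],
)
# encoder_in_channels is forwarded as a kwarg so the hybrid encoder matches the smaller backbone outputs.
config = RTDetrConfig.from_backbone_configs(backbone_config=backbone_config, encoder_in_channels=[128, 256, 512])
```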

@@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RT-DETR ResNet model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`RTDetrResNetBackbone`]. It is used to instantiate a
ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the ResNet
[microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
embedding_size (`int`, *optional*, defaults to 64):
Dimensionality (hidden size) for the embedding layer.
hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
Dimensionality (hidden size) at each stage.
depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
Depth (number of layers) for each stage.
layer_type (`str`, *optional*, defaults to `"bottleneck"`):
The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
`"bottleneck"` (used for larger models like resnet-50 and above).
hidden_act (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
are supported.
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2.
downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example:
```python
>>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
>>> # Initializing a ResNet resnet-50 style configuration
>>> configuration = RTDetrResNetConfig()
>>> # Initializing a model (with random weights) from the resnet-50 style configuration
>>> model = RTDetrResNetBackbone(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "rt_detr_resnet"
layer_types = ["basic", "bottleneck"]
def __init__(
self,
num_channels=3,
embedding_size=64,
hidden_sizes=[256, 512, 1024, 2048],
depths=[3, 4, 6, 3],
layer_type="bottleneck",
hidden_act="relu",
downsample_in_first_stage=False,
downsample_in_bottleneck=False,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
if layer_type not in self.layer_types:
raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
self.num_channels = num_channels
self.embedding_size = embedding_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage
self.downsample_in_bottleneck = downsample_in_bottleneck
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)

@@ -0,0 +1,782 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RT Detr checkpoints with Timm backbone"""
import argparse
import json
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import RTDetrConfig, RTDetrForObjectDetection, RTDetrImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_rt_detr_config(model_name: str) -> RTDetrConfig:
config = RTDetrConfig()
config.num_labels = 80
repo_id = "huggingface/label-files"
filename = "coco-detection-mmdet-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if model_name == "rtdetr_r18vd":
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
config.backbone_config.depths = [2, 2, 2, 2]
config.backbone_config.layer_type = "basic"
config.encoder_in_channels = [128, 256, 512]
config.hidden_expansion = 0.5
config.decoder_layers = 3
elif model_name == "rtdetr_r34vd":
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
config.backbone_config.depths = [3, 4, 6, 3]
config.backbone_config.layer_type = "basic"
config.encoder_in_channels = [128, 256, 512]
config.hidden_expansion = 0.5
config.decoder_layers = 4
elif model_name == "rtdetr_r50vd_m":
pass
elif model_name == "rtdetr_r50vd":
pass
elif model_name == "rtdetr_r101vd":
config.backbone_config.depths = [3, 4, 23, 3]
config.encoder_ffn_dim = 2048
config.encoder_hidden_dim = 384
config.decoder_in_channels = [384, 384, 384]
elif model_name == "rtdetr_r18vd_coco_o365":
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
config.backbone_config.depths = [2, 2, 2, 2]
config.backbone_config.layer_type = "basic"
config.encoder_in_channels = [128, 256, 512]
config.hidden_expansion = 0.5
config.decoder_layers = 3
elif model_name == "rtdetr_r50vd_coco_o365":
pass
elif model_name == "rtdetr_r101vd_coco_o365":
config.backbone_config.depths = [3, 4, 23, 3]
config.encoder_ffn_dim = 2048
config.encoder_hidden_dim = 384
config.decoder_in_channels = [384, 384, 384]
return config
def create_rename_keys(config):
# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
# stem
# fmt: off
last_key = ["weight", "bias", "running_mean", "running_var"]
for level in range(3):
rename_keys.append((f"backbone.conv1.conv1_{level+1}.conv.weight", f"model.backbone.model.embedder.embedder.{level}.convolution.weight"))
for last in last_key:
rename_keys.append((f"backbone.conv1.conv1_{level+1}.norm.{last}", f"model.backbone.model.embedder.embedder.{level}.normalization.{last}"))
for stage_idx in range(len(config.backbone_config.depths)):
for layer_idx in range(config.backbone_config.depths[stage_idx]):
# shortcut
if layer_idx == 0:
if stage_idx == 0:
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.weight",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.convolution.weight",
)
)
for last in last_key:
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.0.short.norm.{last}",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.normalization.{last}",
)
)
else:
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.conv.weight",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.convolution.weight",
)
)
for last in last_key:
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.norm.{last}",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.normalization.{last}",
)
)
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.conv.weight",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.convolution.weight",
)
)
for last in last_key:
rename_keys.append((
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.norm.{last}",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.{last}",
))
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.conv.weight",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.convolution.weight",
)
)
for last in last_key:
rename_keys.append((
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.norm.{last}",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.{last}",
))
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/nn/backbone/presnet.py#L171
if config.backbone_config.layer_type != "basic":
rename_keys.append(
(
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.conv.weight",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.convolution.weight",
)
)
for last in last_key:
rename_keys.append((
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.norm.{last}",
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.normalization.{last}",
))
# fmt: on
for i in range(config.encoder_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.self_attn.out_proj.weight",
f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.self_attn.out_proj.bias",
f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.bias",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.linear1.weight",
f"model.encoder.encoder.{i}.layers.0.fc1.weight",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.linear1.bias",
f"model.encoder.encoder.{i}.layers.0.fc1.bias",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.linear2.weight",
f"model.encoder.encoder.{i}.layers.0.fc2.weight",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.linear2.bias",
f"model.encoder.encoder.{i}.layers.0.fc2.bias",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.norm1.weight",
f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.weight",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.norm1.bias",
f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.bias",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.norm2.weight",
f"model.encoder.encoder.{i}.layers.0.final_layer_norm.weight",
)
)
rename_keys.append(
(
f"encoder.encoder.{i}.layers.0.norm2.bias",
f"model.encoder.encoder.{i}.layers.0.final_layer_norm.bias",
)
)
for j in range(0, 3):
rename_keys.append((f"encoder.input_proj.{j}.0.weight", f"model.encoder_input_proj.{j}.0.weight"))
for last in last_key:
rename_keys.append((f"encoder.input_proj.{j}.1.{last}", f"model.encoder_input_proj.{j}.1.{last}"))
block_levels = 3 if config.backbone_config.layer_type != "basic" else 4
for i in range(len(config.encoder_in_channels) - 1):
# encoder layers: hybridencoder parts
for j in range(1, block_levels):
rename_keys.append(
(f"encoder.fpn_blocks.{i}.conv{j}.conv.weight", f"model.encoder.fpn_blocks.{i}.conv{j}.conv.weight")
)
for last in last_key:
rename_keys.append(
(
f"encoder.fpn_blocks.{i}.conv{j}.norm.{last}",
f"model.encoder.fpn_blocks.{i}.conv{j}.norm.{last}",
)
)
rename_keys.append((f"encoder.lateral_convs.{i}.conv.weight", f"model.encoder.lateral_convs.{i}.conv.weight"))
for last in last_key:
rename_keys.append(
(f"encoder.lateral_convs.{i}.norm.{last}", f"model.encoder.lateral_convs.{i}.norm.{last}")
)
for j in range(3):
for k in range(1, 3):
rename_keys.append(
(
f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
)
)
for last in last_key:
rename_keys.append(
(
f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
)
)
for j in range(1, block_levels):
rename_keys.append(
(f"encoder.pan_blocks.{i}.conv{j}.conv.weight", f"model.encoder.pan_blocks.{i}.conv{j}.conv.weight")
)
for last in last_key:
rename_keys.append(
(
f"encoder.pan_blocks.{i}.conv{j}.norm.{last}",
f"model.encoder.pan_blocks.{i}.conv{j}.norm.{last}",
)
)
for j in range(3):
for k in range(1, 3):
rename_keys.append(
(
f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
)
)
for last in last_key:
rename_keys.append(
(
f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
)
)
rename_keys.append(
(f"encoder.downsample_convs.{i}.conv.weight", f"model.encoder.downsample_convs.{i}.conv.weight")
)
for last in last_key:
rename_keys.append(
(f"encoder.downsample_convs.{i}.norm.{last}", f"model.encoder.downsample_convs.{i}.norm.{last}")
)
for i in range(config.decoder_layers):
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
rename_keys.append(
(
f"decoder.decoder.layers.{i}.self_attn.out_proj.weight",
f"model.decoder.layers.{i}.self_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.self_attn.out_proj.bias",
f"model.decoder.layers.{i}.self_attn.out_proj.bias",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.weight",
f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.bias",
f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.attention_weights.weight",
f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.attention_weights.bias",
f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.value_proj.weight",
f"model.decoder.layers.{i}.encoder_attn.value_proj.weight",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.value_proj.bias",
f"model.decoder.layers.{i}.encoder_attn.value_proj.bias",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.output_proj.weight",
f"model.decoder.layers.{i}.encoder_attn.output_proj.weight",
)
)
rename_keys.append(
(
f"decoder.decoder.layers.{i}.cross_attn.output_proj.bias",
f"model.decoder.layers.{i}.encoder_attn.output_proj.bias",
)
)
rename_keys.append(
(f"decoder.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append(
(f"decoder.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias")
)
rename_keys.append(
(f"decoder.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight")
)
rename_keys.append(
(f"decoder.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias")
)
rename_keys.append((f"decoder.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight"))
rename_keys.append((f"decoder.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias"))
rename_keys.append((f"decoder.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight"))
rename_keys.append((f"decoder.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias"))
rename_keys.append(
(f"decoder.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight")
)
rename_keys.append(
(f"decoder.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias")
)
for i in range(config.decoder_layers):
# decoder + class and bounding box heads
rename_keys.append(
(
f"decoder.dec_score_head.{i}.weight",
f"model.decoder.class_embed.{i}.weight",
)
)
rename_keys.append(
(
f"decoder.dec_score_head.{i}.bias",
f"model.decoder.class_embed.{i}.bias",
)
)
rename_keys.append(
(
f"decoder.dec_bbox_head.{i}.layers.0.weight",
f"model.decoder.bbox_embed.{i}.layers.0.weight",
)
)
rename_keys.append(
(
f"decoder.dec_bbox_head.{i}.layers.0.bias",
f"model.decoder.bbox_embed.{i}.layers.0.bias",
)
)
rename_keys.append(
(
f"decoder.dec_bbox_head.{i}.layers.1.weight",
f"model.decoder.bbox_embed.{i}.layers.1.weight",
)
)
rename_keys.append(
(
f"decoder.dec_bbox_head.{i}.layers.1.bias",
f"model.decoder.bbox_embed.{i}.layers.1.bias",
)
)
rename_keys.append(
(
f"decoder.dec_bbox_head.{i}.layers.2.weight",
f"model.decoder.bbox_embed.{i}.layers.2.weight",
)
)
rename_keys.append(
(
f"decoder.dec_bbox_head.{i}.layers.2.bias",
f"model.decoder.bbox_embed.{i}.layers.2.bias",
)
)
# decoder projection
for i in range(len(config.decoder_in_channels)):
rename_keys.append(
(
f"decoder.input_proj.{i}.conv.weight",
f"model.decoder_input_proj.{i}.0.weight",
)
)
for last in last_key:
rename_keys.append(
(
f"decoder.input_proj.{i}.norm.{last}",
f"model.decoder_input_proj.{i}.1.{last}",
)
)
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
rename_keys.extend(
[
("decoder.denoising_class_embed.weight", "model.denoising_class_embed.weight"),
("decoder.query_pos_head.layers.0.weight", "model.decoder.query_pos_head.layers.0.weight"),
("decoder.query_pos_head.layers.0.bias", "model.decoder.query_pos_head.layers.0.bias"),
("decoder.query_pos_head.layers.1.weight", "model.decoder.query_pos_head.layers.1.weight"),
("decoder.query_pos_head.layers.1.bias", "model.decoder.query_pos_head.layers.1.bias"),
("decoder.enc_output.0.weight", "model.enc_output.0.weight"),
("decoder.enc_output.0.bias", "model.enc_output.0.bias"),
("decoder.enc_output.1.weight", "model.enc_output.1.weight"),
("decoder.enc_output.1.bias", "model.enc_output.1.bias"),
("decoder.enc_score_head.weight", "model.enc_score_head.weight"),
("decoder.enc_score_head.bias", "model.enc_score_head.bias"),
("decoder.enc_bbox_head.layers.0.weight", "model.enc_bbox_head.layers.0.weight"),
("decoder.enc_bbox_head.layers.0.bias", "model.enc_bbox_head.layers.0.bias"),
("decoder.enc_bbox_head.layers.1.weight", "model.enc_bbox_head.layers.1.weight"),
("decoder.enc_bbox_head.layers.1.bias", "model.enc_bbox_head.layers.1.bias"),
("decoder.enc_bbox_head.layers.2.weight", "model.enc_bbox_head.layers.2.weight"),
("decoder.enc_bbox_head.layers.2.bias", "model.enc_bbox_head.layers.2.bias"),
]
)
return rename_keys
def rename_key(state_dict, old, new):
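    # keys that are missing from the original checkpoint (e.g. for smaller model variants) are skipped silently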
try:
val = state_dict.pop(old)
state_dict[new] = val
except Exception:
pass
def read_in_q_k_v(state_dict, config):
prefix = ""
encoder_hidden_dim = config.encoder_hidden_dim
# first: transformer encoder
for i in range(config.encoder_layers):
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[
:encoder_hidden_dim, :
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[
encoder_hidden_dim : 2 * encoder_hidden_dim, :
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[
encoder_hidden_dim : 2 * encoder_hidden_dim
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[
-encoder_hidden_dim:, :
]
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:]
# next: transformer decoder (which is a bit more complex because it also includes cross-attention)
for i in range(config.decoder_layers):
# read in weights + bias of input projection layer of self-attention
in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
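        # 256 corresponds to the decoder hidden size (config.d_model) used by the released RT-DETR checkpoints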
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id):
"""
Copy/paste/tweak model's weights to our RTDETR structure.
"""
# load default config
config = get_rt_detr_config(model_name)
# load original model from torch hub
model_name_to_checkpoint_url = {
"rtdetr_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth",
"rtdetr_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth",
"rtdetr_r50vd_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth",
"rtdetr_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth",
"rtdetr_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth",
"rtdetr_r18vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth",
"rtdetr_r50vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth",
"rtdetr_r101vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_2x_coco_objects365_from_paddle.pth",
}
logger.info(f"Converting model {model_name}...")
state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[
"ema"
]["module"]
# rename keys
for src, dest in create_rename_keys(config):
rename_key(state_dict, src, dest)
# query, key and value matrices need special treatment
read_in_q_k_v(state_dict, config)
    # clean up the state dict: drop the BatchNorm tracking stats and, for the two-stage variant, duplicate the
    # decoder's box/class head weights under the top-level names expected by RTDetrForObjectDetection
for key in state_dict.copy().keys():
if key.endswith("num_batches_tracked"):
del state_dict[key]
# for two_stage
if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key):
state_dict[key.split("model.decoder.")[-1]] = state_dict[key]
# finally, create HuggingFace model and load state dict
model = RTDetrForObjectDetection(config)
model.load_state_dict(state_dict)
model.eval()
# load image processor
image_processor = RTDetrImageProcessor()
# prepare image
img = prepare_img()
# preprocess image
transformations = transforms.Compose(
[
transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension
encoding = image_processor(images=img, return_tensors="pt")
pixel_values = encoding["pixel_values"]
assert torch.allclose(original_pixel_values, pixel_values)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
pixel_values = pixel_values.to(device)
    # Pass the image through the model
outputs = model(pixel_values)
if model_name == "rtdetr_r18vd":
expected_slice_logits = torch.tensor(
[
[-4.3364253, -6.465683, -3.6130402],
[-4.083815, -6.4039373, -6.97881],
[-4.192215, -7.3410473, -6.9027247],
]
)
expected_slice_boxes = torch.tensor(
[
[0.16868353, 0.19833282, 0.21182671],
[0.25559652, 0.55121744, 0.47988364],
[0.7698693, 0.4124569, 0.46036878],
]
)
elif model_name == "rtdetr_r34vd":
expected_slice_logits = torch.tensor(
[
[-4.3727384, -4.7921476, -5.7299604],
[-4.840536, -8.455345, -4.1745796],
[-4.1277084, -5.2154565, -5.7852697],
]
)
expected_slice_boxes = torch.tensor(
[
[0.258278, 0.5497808, 0.4732004],
[0.16889669, 0.19890057, 0.21138911],
[0.76632994, 0.4147879, 0.46851268],
]
)
elif model_name == "rtdetr_r50vd_m":
expected_slice_logits = torch.tensor(
[
[-4.319764, -6.1349025, -6.094794],
[-5.1056995, -7.744766, -4.803956],
[-4.7685347, -7.9278393, -4.5751696],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2582739, 0.55071366, 0.47660282],
[0.16811174, 0.19954777, 0.21292639],
[0.54986024, 0.2752091, 0.0561416],
]
)
elif model_name == "rtdetr_r50vd":
expected_slice_logits = torch.tensor(
[
[-4.6476398, -5.001154, -4.9785104],
[-4.1593494, -4.7038546, -5.946485],
[-4.4374595, -4.658361, -6.2352347],
]
)
expected_slice_boxes = torch.tensor(
[
[0.16880608, 0.19992264, 0.21225442],
[0.76837635, 0.4122631, 0.46368608],
[0.2595386, 0.5483334, 0.4777486],
]
)
elif model_name == "rtdetr_r101vd":
expected_slice_logits = torch.tensor(
[
[-4.6162, -4.9189, -4.6656],
[-4.4701, -4.4997, -4.9659],
[-5.6641, -7.9000, -5.0725],
]
)
expected_slice_boxes = torch.tensor(
[
[0.7707, 0.4124, 0.4585],
[0.2589, 0.5492, 0.4735],
[0.1688, 0.1993, 0.2108],
]
)
elif model_name == "rtdetr_r18vd_coco_o365":
expected_slice_logits = torch.tensor(
[
[-4.8726, -5.9066, -5.2450],
[-4.8157, -6.8764, -5.1656],
[-4.7492, -5.7006, -5.1333],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2552, 0.5501, 0.4773],
[0.1685, 0.1986, 0.2104],
[0.7692, 0.4141, 0.4620],
]
)
elif model_name == "rtdetr_r50vd_coco_o365":
expected_slice_logits = torch.tensor(
[
[-4.6491, -3.9252, -5.3163],
[-4.1386, -5.0348, -3.9016],
[-4.4778, -4.5423, -5.7356],
]
)
expected_slice_boxes = torch.tensor(
[
[0.2583, 0.5492, 0.4747],
[0.5501, 0.2754, 0.0574],
[0.7693, 0.4137, 0.4613],
]
)
elif model_name == "rtdetr_r101vd_coco_o365":
expected_slice_logits = torch.tensor(
[
[-4.5152, -5.6811, -5.7311],
[-4.5358, -7.2422, -5.0941],
[-4.6919, -5.5834, -6.0145],
]
)
expected_slice_boxes = torch.tensor(
[
[0.7703, 0.4140, 0.4583],
[0.1686, 0.1991, 0.2107],
[0.2570, 0.5496, 0.4750],
]
)
else:
raise ValueError(f"Unknown rt_detr_name: {model_name}")
assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4)
assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3)
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
print(f"Saving image processor to {pytorch_dump_folder_path}")
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Upload model, image processor and config to the hub
logger.info("Uploading PyTorch model and image processor to the hub...")
config.push_to_hub(
repo_id=repo_id, commit_message="Add config from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py"
)
model.push_to_hub(
repo_id=repo_id, commit_message="Add model from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py"
)
image_processor.push_to_hub(
repo_id=repo_id,
commit_message="Add image processor from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="rtdetr_r50vd",
type=str,
help="model_name of the checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
parser.add_argument(
"--repo_id",
type=str,
help="repo_id where the model will be pushed to.",
)
args = parser.parse_args()
convert_rt_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id)
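
For reference, a minimal inference sketch (separate from the conversion script itself) showing how the converted checkpoint can be loaded back and its raw outputs turned into detections. Here `converted_rtdetr` is a placeholder for whatever `--pytorch_dump_folder_path` was used, `post_process_object_detection` is assumed to behave like the other DETR-family post-processors, and `id2label` is assumed to carry the COCO class names set up during conversion:

import requests
import torch
from PIL import Image

from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

model = RTDetrForObjectDetection.from_pretrained("converted_rtdetr")  # placeholder path
image_processor = RTDetrImageProcessor.from_pretrained("converted_rtdetr")
model.eval()

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# rescale the normalized boxes back to the original (height, width) and keep confident detections
results = image_processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=[image.size[::-1]]
)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score:.2f} {[round(c, 1) for c in box.tolist()]}")
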


@@ -0,0 +1,426 @@
# coding=utf-8
# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch RTDetr specific ResNet model. The main difference with the Hugging Face ResNet model is that this RTDetrResNet model forces the use of a shortcut in the first layer of the ResNet-18/34 variants.
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
"""
from typing import Optional
from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_rt_detr_resnet import RTDetrResNetConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "RTDetrResNetConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/resnet-50"
_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]
# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer
class RTDetrResNetConvLayer(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
):
super().__init__()
self.convolution = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
)
self.normalization = nn.BatchNorm2d(out_channels)
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
def forward(self, input: Tensor) -> Tensor:
hidden_state = self.convolution(input)
hidden_state = self.normalization(hidden_state)
hidden_state = self.activation(hidden_state)
return hidden_state
class RTDetrResNetEmbeddings(nn.Module):
"""
ResNet Embeddings (stem) composed of a deep aggressive convolution.
"""
def __init__(self, config: RTDetrResNetConfig):
super().__init__()
self.embedder = nn.Sequential(
*[
RTDetrResNetConvLayer(
config.num_channels,
config.embedding_size // 2,
kernel_size=3,
stride=2,
activation=config.hidden_act,
),
RTDetrResNetConvLayer(
config.embedding_size // 2,
config.embedding_size // 2,
kernel_size=3,
stride=1,
activation=config.hidden_act,
),
RTDetrResNetConvLayer(
config.embedding_size // 2,
config.embedding_size,
kernel_size=3,
stride=1,
activation=config.hidden_act,
),
]
)
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
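        # together with the stride-2 first convolution, the max pool reduces the spatial resolution by a factor of 4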
self.num_channels = config.num_channels
def forward(self, pixel_values: Tensor) -> Tensor:
num_channels = pixel_values.shape[1]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embedding = self.embedder(pixel_values)
embedding = self.pooler(embedding)
return embedding
# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetShortCut
class RTDetrResNetShortCut(nn.Module):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
"""
def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
super().__init__()
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
def forward(self, input: Tensor) -> Tensor:
hidden_state = self.convolution(input)
hidden_state = self.normalization(hidden_state)
return hidden_state
class RTDetrResNetBasicLayer(nn.Module):
"""
A classic ResNet residual layer composed of two `3x3` convolutions.
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
"""
def __init__(
self,
config: RTDetrResNetConfig,
in_channels: int,
out_channels: int,
stride: int = 1,
should_apply_shortcut: bool = False,
):
super().__init__()
if in_channels != out_channels:
self.shortcut = (
nn.Sequential(
*[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
)
if should_apply_shortcut
else nn.Identity()
)
else:
self.shortcut = (
RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
if should_apply_shortcut
else nn.Identity()
)
self.layer = nn.Sequential(
RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
)
self.activation = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
residual = hidden_state
hidden_state = self.layer(hidden_state)
residual = self.shortcut(residual)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class RTDetrResNetBottleNeckLayer(nn.Module):
"""
A classic RTDetrResNet bottleneck layer composed of three convolutions.
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
`downsample_in_bottleneck` is `True`, downsampling is applied in the first `1x1` convolution instead of the `3x3` convolution.
"""
def __init__(
self,
config: RTDetrResNetConfig,
in_channels: int,
out_channels: int,
stride: int = 1,
):
super().__init__()
reduction = 4
should_apply_shortcut = in_channels != out_channels or stride != 1
reduces_channels = out_channels // reduction
if stride == 2:
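            # ResNet-D ("vd") style shortcut: downsample with a 2x2 average pool so the 1x1 projection can keep stride 1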
self.shortcut = nn.Sequential(
*[
nn.AvgPool2d(2, 2, 0, ceil_mode=True),
RTDetrResNetShortCut(in_channels, out_channels, stride=1)
if should_apply_shortcut
else nn.Identity(),
]
)
else:
self.shortcut = (
RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
if should_apply_shortcut
else nn.Identity()
)
self.layer = nn.Sequential(
RTDetrResNetConvLayer(
in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
),
RTDetrResNetConvLayer(
reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
),
RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
)
self.activation = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
residual = hidden_state
hidden_state = self.layer(hidden_state)
residual = self.shortcut(residual)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class RTDetrResNetStage(nn.Module):
"""
A RTDetrResNet stage composed by stacked layers.
"""
def __init__(
self,
config: RTDetrResNetConfig,
in_channels: int,
out_channels: int,
stride: int = 2,
depth: int = 2,
):
super().__init__()
layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer
if config.layer_type == "bottleneck":
first_layer = layer(
config,
in_channels,
out_channels,
stride=stride,
)
else:
first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True)
self.layers = nn.Sequential(
first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)]
)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet
class RTDetrResNetEncoder(nn.Module):
def __init__(self, config: RTDetrResNetConfig):
super().__init__()
self.stages = nn.ModuleList([])
# based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
self.stages.append(
RTDetrResNetStage(
config,
config.embedding_size,
config.hidden_sizes[0],
stride=2 if config.downsample_in_first_stage else 1,
depth=config.depths[0],
)
)
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth))
def forward(
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
) -> BaseModelOutputWithNoAttention:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
hidden_state = stage_module(hidden_state)
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet
class RTDetrResNetPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = RTDetrResNetConfig
base_model_prefix = "resnet"
main_input_name = "pixel_values"
_no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"]
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
RTDETR_RESNET_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`RTDetrResNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
RTDETR_RESNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`RTDetrImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"""
ResNet backbone, to be used with frameworks like RTDETR.
""",
RTDETR_RESNET_START_DOCSTRING,
)
class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.num_features = [config.embedding_size] + config.hidden_sizes
self.embedder = RTDetrResNetEmbeddings(config)
self.encoder = RTDetrResNetEncoder(config)
# initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(RTDETR_RESNET_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
) -> BackboneOutput:
"""
Returns:
Examples:
```python
>>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
>>> import torch
>>> config = RTDetrResNetConfig()
>>> model = RTDetrResNetBackbone(config)
>>> pixel_values = torch.randn(1, 3, 224, 224)
>>> with torch.no_grad():
... outputs = model(pixel_values)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 2048, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
embedding_output = self.embedder(pixel_values)
outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
hidden_states = outputs.hidden_states
feature_maps = ()
for idx, stage in enumerate(self.stage_names):
if stage in self.out_features:
feature_maps += (hidden_states[idx],)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs.hidden_states,)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=None,
)
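
As a complement to the docstring example above, the backbone can also return several stages at once, which is how the RT-DETR hybrid encoder consumes it. A minimal sketch with randomly initialized weights and the default configuration sizes:

import torch

from transformers import RTDetrResNetBackbone, RTDetrResNetConfig

# request the last three stages (strides 8, 16 and 32), as RT-DETR does
config = RTDetrResNetConfig(out_features=["stage2", "stage3", "stage4"])
model = RTDetrResNetBackbone(config)

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)

for stage, feature_map in zip(config.out_features, outputs.feature_maps):
    print(stage, tuple(feature_map.shape))  # e.g. stage4 -> (1, 2048, 7, 7) for a 224x224 input
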

@@ -113,10 +113,10 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
return super()._from_config(config, **kwargs)
def freeze_batch_norm_2d(self):
timm.layers.freeze_batch_norm_2d(self._backbone)
timm.utils.model.freeze_batch_norm_2d(self._backbone)
def unfreeze_batch_norm_2d(self):
timm.layers.unfreeze_batch_norm_2d(self._backbone)
timm.utils.model.unfreeze_batch_norm_2d(self._backbone)
def _init_weights(self, module):
"""

@@ -313,7 +313,6 @@ def load_backbone(config):
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
backbone_checkpoint = getattr(config, "backbone", None)
backbone_kwargs = getattr(config, "backbone_kwargs", None)
backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs
if backbone_kwargs and backbone_config is not None:

@@ -7492,6 +7492,41 @@ def load_tf_weights_in_roformer(*args, **kwargs):
requires_backends(load_tf_weights_in_roformer, ["torch"])
class RTDetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RTDetrModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RTDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RTDetrResNetBackbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RTDetrResNetPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class RwkvForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

@@ -492,6 +492,13 @@ class PvtImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class RTDetrImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class SamImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

@@ -0,0 +1,364 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import unittest
import requests
from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_vision_available():
from PIL import Image
from transformers import RTDetrImageProcessor
if is_torch_available():
import torch
class RTDetrImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=4,
num_channels=3,
do_resize=True,
size=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=False,
do_pad=False,
return_tensors="pt",
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.do_resize = do_resize
self.size = size if size is not None else {"height": 640, "width": 640}
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_pad = do_pad
self.return_tensors = return_tensors
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_normalize": self.do_normalize,
"do_pad": self.do_pad,
"return_tensors": self.return_tensors,
}
def get_expected_values(self):
return self.size["height"], self.size["width"]
def expected_output_image_shape(self, images):
height, width = self.get_expected_values()
return self.num_channels, height, width
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=30,
max_resolution=400,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = RTDetrImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = RTDetrImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "return_tensors"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 640, "width": 640})
def test_valid_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
target = json.loads(f.read())
params = {"image_id": 39769, "annotations": target}
# encode them
image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
# legal encodings (single image)
_ = image_processing(images=image, annotations=params, return_tensors="pt")
_ = image_processing(images=image, annotations=[params], return_tensors="pt")
# legal encodings (batch of one image)
_ = image_processing(images=[image], annotations=params, return_tensors="pt")
_ = image_processing(images=[image], annotations=[params], return_tensors="pt")
# legal encoding (batch of more than one image)
n = 5
_ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt")
# example of an illegal encoding (missing the 'image_id' key)
with self.assertRaises(ValueError) as e:
image_processing(images=image, annotations={"annotations": target}, return_tensors="pt")
self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations"))
# example of an illegal encoding (unequal lengths of images and annotations)
with self.assertRaises(ValueError) as e:
image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt")
self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.")
@slow
def test_call_pytorch_with_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
target = json.loads(f.read())
target = {"image_id": 39769, "annotations": target}
# encode them
image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
# verify pixel values
expected_shape = torch.Size([1, 3, 640, 640])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
expected_slice = torch.tensor([0.5490, 0.5647, 0.5725])
self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4))
# verify area
expected_area = torch.tensor([2827.9883, 5403.4761, 235036.7344, 402070.2188, 71068.8281, 79601.2812])
self.assertTrue(torch.allclose(encoding["labels"][0]["area"], expected_area))
# verify boxes
expected_boxes_shape = torch.Size([6, 4])
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3))
# verify image_id
expected_image_id = torch.tensor([39769])
self.assertTrue(torch.allclose(encoding["labels"][0]["image_id"], expected_image_id))
# verify is_crowd
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
self.assertTrue(torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd))
# verify class_labels
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels))
# verify orig_size
expected_orig_size = torch.tensor([480, 640])
self.assertTrue(torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size))
# verify size
expected_size = torch.tensor([640, 640])
self.assertTrue(torch.allclose(encoding["labels"][0]["size"], expected_size))
@slow
def test_image_processor_outputs(self):
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
image_processing = self.image_processing_class(**self.image_processor_dict)
encoding = image_processing(images=image, return_tensors="pt")
# verify pixel values: shape
expected_shape = torch.Size([1, 3, 640, 640])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# verify pixel values: output values
expected_slice = torch.tensor([0.5490196347236633, 0.5647059082984924, 0.572549045085907])
self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-5))
def test_multiple_images_processor_outputs(self):
images_urls = [
"http://images.cocodataset.org/val2017/000000000139.jpg",
"http://images.cocodataset.org/val2017/000000000285.jpg",
"http://images.cocodataset.org/val2017/000000000632.jpg",
"http://images.cocodataset.org/val2017/000000000724.jpg",
"http://images.cocodataset.org/val2017/000000000776.jpg",
"http://images.cocodataset.org/val2017/000000000785.jpg",
"http://images.cocodataset.org/val2017/000000000802.jpg",
"http://images.cocodataset.org/val2017/000000000872.jpg",
]
images = []
for url in images_urls:
image = Image.open(requests.get(url, stream=True).raw)
images.append(image)
# apply image processing
image_processing = self.image_processing_class(**self.image_processor_dict)
encoding = image_processing(images=images, return_tensors="pt")
# verify if pixel_values is part of the encoding
self.assertIn("pixel_values", encoding)
# verify pixel values: shape
expected_shape = torch.Size([8, 3, 640, 640])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# verify pixel values: output values
expected_slices = torch.tensor(
[
[0.5333333611488342, 0.5568627715110779, 0.5647059082984924],
[0.5372549295425415, 0.4705882668495178, 0.4274510145187378],
[0.3960784673690796, 0.35686275362968445, 0.3686274588108063],
[0.20784315466880798, 0.1882353127002716, 0.15294118225574493],
[0.364705890417099, 0.364705890417099, 0.3686274588108063],
[0.8078432083129883, 0.8078432083129883, 0.8078432083129883],
[0.4431372880935669, 0.4431372880935669, 0.4431372880935669],
[0.19607844948768616, 0.21176472306251526, 0.3607843220233917],
]
)
self.assertTrue(torch.allclose(encoding["pixel_values"][:, 1, 0, :3], expected_slices, atol=1e-5))
@slow
def test_batched_coco_detection_annotations(self):
image_0 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
image_1 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png").resize((800, 800))
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
target = json.loads(f.read())
annotations_0 = {"image_id": 39769, "annotations": target}
annotations_1 = {"image_id": 39769, "annotations": target}
# Adjust the bounding boxes for the resized image
w_0, h_0 = image_0.size
w_1, h_1 = image_1.size
for i in range(len(annotations_1["annotations"])):
coords = annotations_1["annotations"][i]["bbox"]
new_bbox = [
coords[0] * w_1 / w_0,
coords[1] * h_1 / h_0,
coords[2] * w_1 / w_0,
coords[3] * h_1 / h_0,
]
annotations_1["annotations"][i]["bbox"] = new_bbox
images = [image_0, image_1]
annotations = [annotations_0, annotations_1]
image_processing = RTDetrImageProcessor()
encoding = image_processing(
images=images,
annotations=annotations,
return_segmentation_masks=True,
return_tensors="pt", # do_convert_annotations=True
)
# Check the pixel values have been padded
postprocessed_height, postprocessed_width = 640, 640
expected_shape = torch.Size([2, 3, postprocessed_height, postprocessed_width])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# Check the bounding boxes have been adjusted for padded images
self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4]))
self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4]))
expected_boxes_0 = torch.tensor(
[
[0.6879, 0.4609, 0.0755, 0.3691],
[0.2118, 0.3359, 0.2601, 0.1566],
[0.5011, 0.5000, 0.9979, 1.0000],
[0.5010, 0.5020, 0.9979, 0.9959],
[0.3284, 0.5944, 0.5884, 0.8112],
[0.8394, 0.5445, 0.3213, 0.9110],
]
)
expected_boxes_1 = torch.tensor(
[
[0.5503, 0.2765, 0.0604, 0.2215],
[0.1695, 0.2016, 0.2080, 0.0940],
[0.5006, 0.4933, 0.9977, 0.9865],
[0.5008, 0.5002, 0.9983, 0.9955],
[0.2627, 0.5456, 0.4707, 0.8646],
[0.7715, 0.4115, 0.4570, 0.7161],
]
)
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1e-3))
self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1e-3))
# Check that, when do_convert_annotations=False, the annotations are not converted to the centre_x, centre_y,
# width, height format and are not normalized to the range [0, 1]
encoding = image_processing(
images=images,
annotations=annotations,
return_segmentation_masks=True,
do_convert_annotations=False,
return_tensors="pt",
)
self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4]))
self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4]))
# Convert to absolute coordinates
unnormalized_boxes_0 = torch.vstack(
[
expected_boxes_0[:, 0] * postprocessed_width,
expected_boxes_0[:, 1] * postprocessed_height,
expected_boxes_0[:, 2] * postprocessed_width,
expected_boxes_0[:, 3] * postprocessed_height,
]
).T
unnormalized_boxes_1 = torch.vstack(
[
expected_boxes_1[:, 0] * postprocessed_width,
expected_boxes_1[:, 1] * postprocessed_height,
expected_boxes_1[:, 2] * postprocessed_width,
expected_boxes_1[:, 3] * postprocessed_height,
]
).T
# Convert from centre_x, centre_y, width, height to x_min, y_min, x_max, y_max
expected_boxes_0 = torch.vstack(
[
unnormalized_boxes_0[:, 0] - unnormalized_boxes_0[:, 2] / 2,
unnormalized_boxes_0[:, 1] - unnormalized_boxes_0[:, 3] / 2,
unnormalized_boxes_0[:, 0] + unnormalized_boxes_0[:, 2] / 2,
unnormalized_boxes_0[:, 1] + unnormalized_boxes_0[:, 3] / 2,
]
).T
expected_boxes_1 = torch.vstack(
[
unnormalized_boxes_1[:, 0] - unnormalized_boxes_1[:, 2] / 2,
unnormalized_boxes_1[:, 1] - unnormalized_boxes_1[:, 3] / 2,
unnormalized_boxes_1[:, 0] + unnormalized_boxes_1[:, 2] / 2,
unnormalized_boxes_1[:, 1] + unnormalized_boxes_1[:, 3] / 2,
]
).T
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1))
self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1))

@@ -0,0 +1,680 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch RT_DETR model."""
import inspect
import math
import unittest
from transformers import (
RTDetrConfig,
RTDetrImageProcessor,
RTDetrResNetConfig,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_torch, require_vision, torch_device
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import RTDetrForObjectDetection, RTDetrModel
if is_vision_available():
from PIL import Image
CHECKPOINT = "PekingU/rtdetr_r50vd" # TODO: replace
class RTDetrModelTester:
def __init__(
self,
parent,
batch_size=3,
is_training=True,
use_labels=True,
n_targets=3,
num_labels=10,
initializer_range=0.02,
layer_norm_eps=1e-5,
batch_norm_eps=1e-5,
# backbone
backbone_config=None,
# encoder HybridEncoder
encoder_hidden_dim=32,
encoder_in_channels=[128, 256, 512],
feat_strides=[8, 16, 32],
encoder_layers=1,
encoder_ffn_dim=64,
encoder_attention_heads=2,
dropout=0.0,
activation_dropout=0.0,
encode_proj_layers=[2],
positional_encoding_temperature=10000,
encoder_activation_function="gelu",
activation_function="silu",
eval_size=None,
normalize_before=False,
# decoder RTDetrTransformer
d_model=32,
num_queries=30,
decoder_in_channels=[32, 32, 32],
decoder_ffn_dim=64,
num_feature_levels=3,
decoder_n_points=4,
decoder_layers=2,
decoder_attention_heads=2,
decoder_activation_function="relu",
attention_dropout=0.0,
num_denoising=0,
label_noise_ratio=0.5,
box_noise_scale=1.0,
learn_initial_query=False,
anchor_image_size=[64, 64],
image_size=64,
disable_custom_kernels=True,
with_box_refine=True,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = 3
self.is_training = is_training
self.use_labels = use_labels
self.n_targets = n_targets
self.num_labels = num_labels
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.batch_norm_eps = batch_norm_eps
self.backbone_config = backbone_config
self.encoder_hidden_dim = encoder_hidden_dim
self.encoder_in_channels = encoder_in_channels
self.feat_strides = feat_strides
self.encoder_layers = encoder_layers
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.dropout = dropout
self.activation_dropout = activation_dropout
self.encode_proj_layers = encode_proj_layers
self.positional_encoding_temperature = positional_encoding_temperature
self.encoder_activation_function = encoder_activation_function
self.activation_function = activation_function
self.eval_size = eval_size
self.normalize_before = normalize_before
self.d_model = d_model
self.num_queries = num_queries
self.decoder_in_channels = decoder_in_channels
self.decoder_ffn_dim = decoder_ffn_dim
self.num_feature_levels = num_feature_levels
self.decoder_n_points = decoder_n_points
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.decoder_activation_function = decoder_activation_function
self.attention_dropout = attention_dropout
self.num_denoising = num_denoising
self.label_noise_ratio = label_noise_ratio
self.box_noise_scale = box_noise_scale
self.learn_initial_query = learn_initial_query
self.anchor_image_size = anchor_image_size
self.image_size = image_size
self.disable_custom_kernels = disable_custom_kernels
self.with_box_refine = with_box_refine
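        # only the coarsest feature map (stride 32) goes through the transformer layers of the hybrid encoder
        # (encode_proj_layers=[2]), so the attention sequence length is (image_size / 32) ** 2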
self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
labels = None
if self.use_labels:
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
labels = []
for i in range(self.batch_size):
target = {}
target["class_labels"] = torch.randint(
high=self.num_labels, size=(self.n_targets,), device=torch_device
)
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
labels.append(target)
config = self.get_config()
config.num_labels = self.num_labels
return config, pixel_values, pixel_mask, labels
def get_config(self):
hidden_sizes = [10, 20, 30, 40]
backbone_config = RTDetrResNetConfig(
embeddings_size=10,
hidden_sizes=hidden_sizes,
depths=[1, 1, 2, 1],
out_features=["stage2", "stage3", "stage4"],
out_indices=[2, 3, 4],
)
return RTDetrConfig.from_backbone_configs(
backbone_config=backbone_config,
encoder_hidden_dim=self.encoder_hidden_dim,
encoder_in_channels=hidden_sizes[1:],
feat_strides=self.feat_strides,
encoder_layers=self.encoder_layers,
encoder_ffn_dim=self.encoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
dropout=self.dropout,
activation_dropout=self.activation_dropout,
encode_proj_layers=self.encode_proj_layers,
positional_encoding_temperature=self.positional_encoding_temperature,
encoder_activation_function=self.encoder_activation_function,
activation_function=self.activation_function,
eval_size=self.eval_size,
normalize_before=self.normalize_before,
d_model=self.d_model,
num_queries=self.num_queries,
decoder_in_channels=self.decoder_in_channels,
decoder_ffn_dim=self.decoder_ffn_dim,
num_feature_levels=self.num_feature_levels,
decoder_n_points=self.decoder_n_points,
decoder_layers=self.decoder_layers,
decoder_attention_heads=self.decoder_attention_heads,
decoder_activation_function=self.decoder_activation_function,
attention_dropout=self.attention_dropout,
num_denoising=self.num_denoising,
label_noise_ratio=self.label_noise_ratio,
box_noise_scale=self.box_noise_scale,
learn_initial_query=self.learn_initial_query,
anchor_image_size=self.anchor_image_size,
image_size=self.image_size,
disable_custom_kernels=self.disable_custom_kernels,
with_box_refine=self.with_box_refine,
)
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
def create_and_check_rt_detr_model(self, config, pixel_values, pixel_mask, labels):
model = RTDetrModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model))
def create_and_check_rt_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
model = RTDetrForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
@require_torch
class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (RTDetrModel, RTDetrForObjectDetection) if is_torch_available() else ()
pipeline_model_mapping = (
{"image-feature-extraction": RTDetrModel, "object-detection": RTDetrForObjectDetection}
if is_torch_available()
else {}
)
is_encoder_decoder = True
test_torchscript = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class.__name__ == "RTDetrForObjectDetection":
labels = []
for i in range(self.model_tester.batch_size):
target = {}
target["class_labels"] = torch.ones(
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
)
target["boxes"] = torch.ones(
self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
)
labels.append(target)
inputs_dict["labels"] = labels
return inputs_dict
def setUp(self):
self.model_tester = RTDetrModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=RTDetrConfig,
has_text_modality=False,
common_properties=["hidden_size", "num_attention_heads"],
)
def test_config(self):
self.config_tester.run_common_tests()
def test_rt_detr_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_rt_detr_model(*config_and_inputs)
def test_rt_detr_object_detection_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_rt_detr_object_detection_head_model(*config_and_inputs)
@unittest.skip(reason="RTDetr does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="RTDetr does not use test_inputs_embeds_matches_input_ids")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="RTDetr does not support input and output embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="RTDetr does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="RTDetr does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions
self.assertEqual(len(attentions), self.model_tester.encoder_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions
self.assertEqual(len(attentions), self.model_tester.encoder_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[
self.model_tester.encoder_attention_heads,
self.model_tester.encoder_seq_length,
self.model_tester.encoder_seq_length,
],
)
out_len = len(outputs)
correct_outlen = 13
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
# Object Detection model returns pred_logits and pred_boxes
if model_class.__name__ == "RTDetrForObjectDetection":
correct_outlen += 2
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[
self.model_tester.decoder_attention_heads,
self.model_tester.num_queries,
self.model_tester.num_queries,
],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.decoder_attention_heads,
self.model_tester.num_feature_levels,
self.model_tester.decoder_n_points,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
else:
# RTDetr should maintain the encoder_hidden_states output
added_hidden_states = 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions
self.assertEqual(len(self_attentions), self.model_tester.encoder_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[
self.model_tester.encoder_attention_heads,
self.model_tester.encoder_seq_length,
self.model_tester.encoder_seq_length,
],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[1].shape[-2:]),
[
self.model_tester.image_size // self.model_tester.feat_strides[-1],
self.model_tester.image_size // self.model_tester.feat_strides[-1],
],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1
)
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.num_queries, self.model_tester.d_model],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
# we take the first output since last_hidden_state is the first item
output = outputs[0]
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_attentions = outputs.encoder_attentions[0]
encoder_hidden_states.retain_grad()
encoder_attentions.retain_grad()
decoder_attentions = outputs.decoder_attentions[0]
decoder_attentions.retain_grad()
cross_attentions = outputs.cross_attentions[0]
cross_attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(decoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_different_timm_backbone(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if model_class.__name__ == "RTDetrForObjectDetection":
expected_shape = (
self.model_tester.batch_size,
self.model_tester.num_queries,
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)
self.assertTrue(outputs)
def test_hf_backbone(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Load a pretrained HF checkpoint as backbone
config.backbone = "microsoft/resnet-18"
config.backbone_config = None
config.use_timm_backbone = False
config.use_pretrained_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if model_class.__name__ == "RTDetrForObjectDetection":
expected_shape = (
self.model_tester.batch_size,
self.model_tester.num_queries,
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)
self.assertTrue(outputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
for name, module in model.named_modules():
if module.__class__.__name__ == "RTDetrConvEncoder":
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
break
for name, param in model.named_parameters():
if param.requires_grad:
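# parameters listed below use custom (non config-scaled) initialization, e.g. the deformable sampling offsets, so they are excluded from the zero-init check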
if (
"level_embed" in name
or "sampling_offsets.bias" in name
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
or name in backbone_params
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
TOLERANCE = 1e-4
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
class RTDetrModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return RTDetrImageProcessor.from_pretrained(CHECKPOINT) if is_vision_available() else None
def test_inference_object_detection_head(self):
model = RTDetrForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
expected_shape_logits = torch.Size((1, 300, model.config.num_labels))
self.assertEqual(outputs.logits.shape, expected_shape_logits)
expected_logits = torch.tensor(
[
[-4.64763879776001, -5.001153945922852, -4.978509902954102],
[-4.159348487854004, -4.703853607177734, -5.946484565734863],
[-4.437461853027344, -4.65836238861084, -6.235235691070557],
]
).to(torch_device)
expected_boxes = torch.tensor(
[
[0.1688060760498047, 0.19992263615131378, 0.21225441992282867],
[0.768376350402832, 0.41226309537887573, 0.4636859893798828],
[0.25953856110572815, 0.5483334064483643, 0.4777486026287079],
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4))
expected_shape_boxes = torch.Size((1, 300, 4))
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4))
# verify postprocessing
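# (threshold=0.0 keeps every prediction so the top-scoring boxes can be checked explicitly; target_sizes expects (height, width), hence image.size[::-1] since PIL reports (width, height))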
results = image_processor.post_process_object_detection(
outputs, threshold=0.0, target_sizes=[image.size[::-1]]
)[0]
expected_scores = torch.tensor(
[0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493], device=torch_device
)
expected_labels = [57, 15, 15, 65]
expected_slice_boxes = torch.tensor(
[
[0.13774872, 0.37821293, 640.13074, 476.21088],
[343.38132, 24.276838, 640.1404, 371.49573],
[13.225126, 54.179348, 318.98422, 472.2207],
[40.114475, 73.44104, 175.9573, 118.48469],
],
device=torch_device,
)
self.assertTrue(torch.allclose(results["scores"][:4], expected_scores, atol=1e-4))
self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
self.assertTrue(torch.allclose(results["boxes"][:4], expected_slice_boxes, atol=1e-4))


@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import RTDetrResNetConfig
from transformers.testing_utils import require_torch, torch_device
from transformers.utils.import_utils import is_torch_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
from transformers import RTDetrResNetBackbone
class RTDetrResNetModelTester:
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
out_features=["stage2", "stage3", "stage4"],
out_indices=[2, 3, 4],
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
self.out_features = out_features
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return RTDetrResNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
out_features=self.out_features,
out_indices=self.out_indices,
)
def create_and_check_backbone(self, config, pixel_values, labels):
model = RTDetrResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
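# (with image_size=32, the first requested stage ("stage2") runs at 1/8 resolution, i.e. 4x4 maps with hidden_sizes[1] channels)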
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = RTDetrResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
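# (with out_features=None only the last stage is returned: 1/32 resolution, i.e. 1x1 maps with hidden_sizes[-1] channels)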
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class RTDetrResNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
all_model_classes = (RTDetrResNetBackbone,) if is_torch_available() else ()
has_attentions = False
config_class = RTDetrResNetConfig
def setUp(self):
self.model_tester = RTDetrResNetModelTester(self)