mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
New model support RTDETR (#29077)
* fill out docs string in configuration75dcd3a0e8 (r1506391856)
* reduce the input image size for the tests * remove the unappropriate tests * only 5 failes exists * make style * fill up missed architecture for object detection in docs * fix auto modeling * simple fix in missing import * major change including backbone refactor and objectdetectionoutput refactor * minor fix only 4 fails left * intermediate fix * revert __init__.py * revert __init__.py * make style * fixes in pr_docs * intermediate fix * make style * two fixes * pass doctest * only one fix left * intermediate commit * all fixed * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/rt_detr/test_modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * function class above the model definition in dice_loss * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * simple fix * layernorm add config.layer_norm_eps * fix inputs_docstring * make style * simple fix * add custom coco loading test in image_processor * fix error in BaseModelOutput https://github.com/huggingface/transformers/pull/29077#discussion_r1516657790 * simple typo * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * intermediate fix * fix with load_backbone format * remove unused configuration * 3 fix test left * make style * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: Sounak Dey <dey.sounak@gmail.com> * change last_hidden_state to first index * all pass fix TO DO: minor update in comments * make fix-copies * remove deepcopy * pr_document fix * revert deepcopy due to the issue of unexpceted behavior in decoderlayer * add atol in final * add no_split_module * _no_split_modules = None * device transfer for model parallelism * minor fix * make fix-copies * fix typo * add test_image_processor with post_processing * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add config in RTDETRPredictionHead * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * set lru_cache with max_size 32 * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add lru_cache import and configuration change * change the order of definition * make fix-copies * add docs and change config error * revert strange make-fix * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * test pass * fix get_clones related and remove deepcopy * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * nit for paper section * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * rename denoising related parameters * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * check the image transformation logic * make style * make style * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * pe_encoding -> positional_encoding_temperature * remove TODO * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * remove eval_idx since transformer DETR is giving all decoder output * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * change variable name * make style and docs import update * Revert "Update src/transformers/models/rt_detr/image_processing_rt_detr.py" This reverts commit74aa3e1de0
. * fix typo * add postprocessing in docs * move import scipy to top * change varaible name * make fix-copies * remove eval_idx in test * move to after first sentence * update image_processor since box loss requires normalized one * change appropriate name to auxiliary_outputs * Update src/transformers/models/rt_detr/__init__.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/rt_detr/__init__.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * make style * remove panoptic related comments * make style * revert valid_processor_keys * fix aux related test * make style * change origination from config to backbone API * enable the dn_loss * fix test and conversion * renewal weight initialization * change initializer_range * make fix-up * fix the loss issue in the auxiliary output and denoising part * change weight loss to original RTDETR * fix in initialization * sync shape format of dn and aux * make style * stable fine-tuning and compatible conversion for resnet101 * make style * skip input_embed * change encoder related variable * enable converting rtdetr_r101 * add r101 related conversion code * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/image_processing_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * change name _shape to _reshape * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * maket style * make fix-copies * remove deprecated import * more fix * remove last_hidden_state for task-specific model * Revert "remove last_hidden_state for task-specific model" This reverts commitccb7a34051
. * minore change in convert * remove print * make style and fix-copies * add custom rtdetr backbone for r18, r34 * remove print * change copied * add pad_size * make style * change layertype to optional to pass the CI * make style * add test in modeling_resnet_rt_detr * make fix-copies * skip tmp file test * fix comment * add docs * change to modeling_resnet file format * enabling resnet50 above * Update src/transformers/models/rt_detr/modeling_rt_detr.py Co-authored-by: Jason Wu <jasonkit@users.noreply.github.com> * enable all the rtdetr model :) * finish except CI * add RTDetrResNetBackbone * make fix-copies * fix TO DO: CI enable * make style * rename test * add docs * add special fix * revert resnet * Update src/transformers/models/rt_detr/modeling_rt_detr_resnet.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * add more comment * remove swin comment * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * rename convert and add verify backbone * Update docs/source/en/_toctree.yml Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/rt_detr.md Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * make style * requests for docs * more general test docs * general script docs * make fix-copies * final commit * Revert "Update src/transformers/models/rt_detr/configuration_rt_detr.py" This reverts commitd136225cd3
. * skip test_model_get_set_embeddings * remove target * add changes * make fix-copies * remove decoder_attention_mask * add load_backbone function for auto_backbone * remove comment * fix repo name * Update src/transformers/models/rt_detr/configuration_rt_detr.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final commit * remove unused downsample_in_bottleneck * new test for autobackbone * change to appropriate indices * test fix * fix dict in test_image_processor * fix test * [run-slow] rt_detr, rt_detr_resnet * change the slow test * [run-slow] rt_detr * [run-slow] rt_detr, rt_detr_resnet * make in to same cuda in CSPRepLayer * [run-slow] rt_detr, rt_detr_resnet --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sounak Dey <dey.sounak@gmail.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Jason Wu <jasonkit@users.noreply.github.com> Co-authored-by: ChoiSangBum <choisangbum@ChoiSangBumui-MacBookPro.local>
This commit is contained in:
parent
8b7cd40273
commit
74a207404e
@ -627,6 +627,8 @@
|
||||
title: RegNet
|
||||
- local: model_doc/resnet
|
||||
title: ResNet
|
||||
- local: model_doc/rt_detr
|
||||
title: RT-DETR
|
||||
- local: model_doc/segformer
|
||||
title: SegFormer
|
||||
- local: model_doc/seggpt
|
||||
|
@ -262,6 +262,8 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm) | ✅ | ✅ | ✅ |
|
||||
| [RoCBert](model_doc/roc_bert) | ✅ | ❌ | ❌ |
|
||||
| [RoFormer](model_doc/roformer) | ✅ | ✅ | ✅ |
|
||||
| [RT-DETR](model_doc/rt_detr) | ✅ | ❌ | ❌ |
|
||||
| [RT-DETR-ResNet](model_doc/rt_detr_resnet) | ✅ | ❌ | ❌ |
|
||||
| [RWKV](model_doc/rwkv) | ✅ | ❌ | ❌ |
|
||||
| [SAM](model_doc/sam) | ✅ | ✅ | ❌ |
|
||||
| [SeamlessM4T](model_doc/seamless_m4t) | ✅ | ❌ | ❌ |
|
||||
|
85
docs/source/en/model_doc/rt_detr.md
Normal file
85
docs/source/en/model_doc/rt_detr.md
Normal file
@ -0,0 +1,85 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# RT-DETR
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
The RT-DETR model was proposed in [DETRs Beat YOLOs on Real-time Object Detection](https://arxiv.org/abs/2304.08069) by Wenyu Lv, Yian Zhao, Shangliang Xu, Jinman Wei, Guanzhong Wang, Cheng Cui, Yuning Du, Qingqing Dang, Yi Liu.
|
||||
|
||||
RT-DETR is an object detection model that stands for "Real-Time DEtection Transformer." This model is designed to perform object detection tasks with a focus on achieving real-time performance while maintaining high accuracy. Leveraging the transformer architecture, which has gained significant popularity in various fields of deep learning, RT-DETR processes images to identify and locate multiple objects within them.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS.*
|
||||
|
||||
The model version was contributed by [rafaelpadilla](https://huggingface.co/rafaelpadilla) and [sangbumchoi](https://github.com/SangbumChoi). The original code can be found [here](https://github.com/lyuwenyu/RT-DETR/).
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
Initially, an image is processed using a pre-trained convolutional neural network, specifically a Resnet-D variant as referenced in the original code. This network extracts features from the final three layers of the architecture. Following this, a hybrid encoder is employed to convert the multi-scale features into a sequential array of image features. Then, a decoder, equipped with auxiliary prediction heads is used to refine the object queries. This process facilitates the direct generation of bounding boxes, eliminating the need for any additional post-processing to acquire the logits and coordinates for the bounding boxes.
|
||||
|
||||
```py
|
||||
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
|
||||
from PIL import Image
|
||||
import json
|
||||
import torch
|
||||
import requests
|
||||
|
||||
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
|
||||
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
|
||||
|
||||
inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]), threshold=0.3)
|
||||
```
|
||||
|
||||
## RTDetrConfig
|
||||
|
||||
[[autodoc]] RTDetrConfig
|
||||
|
||||
## RTDetrResNetConfig
|
||||
|
||||
[[autodoc]] RTDetrResNetConfig
|
||||
|
||||
## RTDetrImageProcessor
|
||||
|
||||
[[autodoc]] RTDetrImageProcessor
|
||||
- preprocess
|
||||
- post_process_object_detection
|
||||
|
||||
## RTDetrModel
|
||||
|
||||
[[autodoc]] RTDetrModel
|
||||
- forward
|
||||
|
||||
## RTDetrForObjectDetection
|
||||
|
||||
[[autodoc]] RTDetrForObjectDetection
|
||||
- forward
|
||||
|
||||
## RTDetrResNetBackbone
|
||||
|
||||
[[autodoc]] RTDetrResNetBackbone
|
||||
- forward
|
@ -654,6 +654,7 @@ _import_structure = {
|
||||
"RoFormerConfig",
|
||||
"RoFormerTokenizer",
|
||||
],
|
||||
"models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"],
|
||||
"models.rwkv": ["RwkvConfig"],
|
||||
"models.sam": [
|
||||
"SamConfig",
|
||||
@ -1153,6 +1154,7 @@ else:
|
||||
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
|
||||
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
|
||||
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
|
||||
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
|
||||
_import_structure["models.sam"].extend(["SamImageProcessor"])
|
||||
_import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"])
|
||||
_import_structure["models.seggpt"].extend(["SegGptImageProcessor"])
|
||||
@ -3004,6 +3006,15 @@ else:
|
||||
"load_tf_weights_in_roformer",
|
||||
]
|
||||
)
|
||||
_import_structure["models.rt_detr"].extend(
|
||||
[
|
||||
"RTDetrForObjectDetection",
|
||||
"RTDetrModel",
|
||||
"RTDetrPreTrainedModel",
|
||||
"RTDetrResNetBackbone",
|
||||
"RTDetrResNetPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.rwkv"].extend(
|
||||
[
|
||||
"RwkvForCausalLM",
|
||||
@ -5270,6 +5281,10 @@ if TYPE_CHECKING:
|
||||
RoFormerConfig,
|
||||
RoFormerTokenizer,
|
||||
)
|
||||
from .models.rt_detr import (
|
||||
RTDetrConfig,
|
||||
RTDetrResNetConfig,
|
||||
)
|
||||
from .models.rwkv import RwkvConfig
|
||||
from .models.sam import (
|
||||
SamConfig,
|
||||
@ -5792,6 +5807,7 @@ if TYPE_CHECKING:
|
||||
PoolFormerImageProcessor,
|
||||
)
|
||||
from .models.pvt import PvtImageProcessor
|
||||
from .models.rt_detr import RTDetrImageProcessor
|
||||
from .models.sam import SamImageProcessor
|
||||
from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor
|
||||
from .models.seggpt import SegGptImageProcessor
|
||||
@ -7295,6 +7311,13 @@ if TYPE_CHECKING:
|
||||
RoFormerPreTrainedModel,
|
||||
load_tf_weights_in_roformer,
|
||||
)
|
||||
from .models.rt_detr import (
|
||||
RTDetrForObjectDetection,
|
||||
RTDetrModel,
|
||||
RTDetrPreTrainedModel,
|
||||
RTDetrResNetBackbone,
|
||||
RTDetrResNetPreTrainedModel,
|
||||
)
|
||||
from .models.rwkv import (
|
||||
RwkvForCausalLM,
|
||||
RwkvModel,
|
||||
|
@ -193,6 +193,7 @@ from . import (
|
||||
roberta_prelayernorm,
|
||||
roc_bert,
|
||||
roformer,
|
||||
rt_detr,
|
||||
rwkv,
|
||||
sam,
|
||||
seamless_m4t,
|
||||
|
@ -214,6 +214,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
|
||||
("roc_bert", "RoCBertConfig"),
|
||||
("roformer", "RoFormerConfig"),
|
||||
("rt_detr", "RTDetrConfig"),
|
||||
("rt_detr_resnet", "RTDetrResNetConfig"),
|
||||
("rwkv", "RwkvConfig"),
|
||||
("sam", "SamConfig"),
|
||||
("seamless_m4t", "SeamlessM4TConfig"),
|
||||
@ -499,6 +501,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
|
||||
("roc_bert", "RoCBert"),
|
||||
("roformer", "RoFormer"),
|
||||
("rt_detr", "RT-DETR"),
|
||||
("rt_detr_resnet", "RT-DETR-ResNet"),
|
||||
("rwkv", "RWKV"),
|
||||
("sam", "SAM"),
|
||||
("seamless_m4t", "SeamlessM4T"),
|
||||
@ -623,6 +627,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("clip_vision_model", "clip"),
|
||||
("siglip_vision_model", "siglip"),
|
||||
("chinese_clip_vision_model", "chinese_clip"),
|
||||
("rt_detr_resnet", "rt_detr"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -114,6 +114,7 @@ else:
|
||||
("pvt_v2", ("PvtImageProcessor",)),
|
||||
("regnet", ("ConvNextImageProcessor",)),
|
||||
("resnet", ("ConvNextImageProcessor",)),
|
||||
("rt_detr", "RTDetrImageProcessor"),
|
||||
("sam", ("SamImageProcessor",)),
|
||||
("segformer", ("SegformerImageProcessor",)),
|
||||
("seggpt", ("SegGptImageProcessor",)),
|
||||
|
@ -202,6 +202,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormModel"),
|
||||
("roc_bert", "RoCBertModel"),
|
||||
("roformer", "RoFormerModel"),
|
||||
("rt_detr", "RTDetrModel"),
|
||||
("rwkv", "RwkvModel"),
|
||||
("sam", "SamModel"),
|
||||
("seamless_m4t", "SeamlessM4TModel"),
|
||||
@ -765,6 +766,7 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||
("deformable_detr", "DeformableDetrForObjectDetection"),
|
||||
("deta", "DetaForObjectDetection"),
|
||||
("detr", "DetrForObjectDetection"),
|
||||
("rt_detr", "RTDetrForObjectDetection"),
|
||||
("table-transformer", "TableTransformerForObjectDetection"),
|
||||
("yolos", "YolosForObjectDetection"),
|
||||
]
|
||||
@ -1252,6 +1254,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
||||
("nat", "NatBackbone"),
|
||||
("pvt_v2", "PvtV2Backbone"),
|
||||
("resnet", "ResNetBackbone"),
|
||||
("rt_detr_resnet", "RTDetrResNetBackbone"),
|
||||
("swin", "SwinBackbone"),
|
||||
("swinv2", "Swinv2Backbone"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
|
@ -29,22 +29,24 @@ from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_torch_cuda_available,
|
||||
is_vision_available,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
)
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import meshgrid
|
||||
from ...utils import is_accelerate_available, is_ninja_available, logging
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_accelerate_available,
|
||||
is_ninja_available,
|
||||
is_scipy_available,
|
||||
is_timm_available,
|
||||
is_torch_cuda_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
)
|
||||
from ...utils.backbone_utils import load_backbone
|
||||
from .configuration_deformable_detr import DeformableDetrConfig
|
||||
|
||||
|
78
src/transformers/models/rt_detr/__init__.py
Normal file
78
src/transformers/models/rt_detr/__init__.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {"configuration_rt_detr": ["RTDetrConfig"], "configuration_rt_detr_resnet": ["RTDetrResNetConfig"]}
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_rt_detr"] = ["RTDetrImageProcessor"]
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_rt_detr"] = [
|
||||
"RTDetrForObjectDetection",
|
||||
"RTDetrModel",
|
||||
"RTDetrPreTrainedModel",
|
||||
]
|
||||
_import_structure["modeling_rt_detr_resnet"] = [
|
||||
"RTDetrResNetBackbone",
|
||||
"RTDetrResNetPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_rt_detr import RTDetrConfig
|
||||
from .configuration_rt_detr_resnet import RTDetrResNetConfig
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_rt_detr import RTDetrImageProcessor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_rt_detr import (
|
||||
RTDetrForObjectDetection,
|
||||
RTDetrModel,
|
||||
RTDetrPreTrainedModel,
|
||||
)
|
||||
from .modeling_rt_detr_resnet import (
|
||||
RTDetrResNetBackbone,
|
||||
RTDetrResNetPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
352
src/transformers/models/rt_detr/configuration_rt_detr.py
Normal file
352
src/transformers/models/rt_detr/configuration_rt_detr.py
Normal file
@ -0,0 +1,352 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RT-DETR model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from .configuration_rt_detr_resnet import RTDetrResNetConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RTDetrConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
|
||||
RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the RT-DETR
|
||||
[checkpoing/todo](https://huggingface.co/checkpoing/todo) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
initializer_range (`float`, *optional*, defaults to 0.01):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the layer normalization layers.
|
||||
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the batch normalization layers.
|
||||
backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
|
||||
The configuration of the backbone model.
|
||||
backbone (`str`, *optional*):
|
||||
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
||||
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use pretrained weights for the backbone.
|
||||
use_timm_backbone (`bool`, *optional*, defaults to `False`):
|
||||
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
|
||||
library.
|
||||
backbone_kwargs (`dict`, *optional*):
|
||||
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||
encoder_hidden_dim (`int`, *optional*, defaults to 256):
|
||||
Dimension of the layers in hybrid encoder.
|
||||
encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
|
||||
Multi level features input for encoder.
|
||||
feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
|
||||
Strides used in each feature map.
|
||||
encoder_layers (`int`, *optional*, defaults to 1):
|
||||
Total of layers to be used by the encoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The ratio for all dropout layers.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
|
||||
Indexes of the projected layers to be used in the encoder.
|
||||
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
|
||||
The temperature parameter used to create the positional encodings.
|
||||
encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
activation_function (`str`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
eval_size (`Tuple[int, int]`, *optional*):
|
||||
Height and width used to computes the effective height and width of the position embeddings after taking
|
||||
into account the stride.
|
||||
normalize_before (`bool`, *optional*, defaults to `False`):
|
||||
Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
|
||||
feed-forward modules.
|
||||
hidden_expansion (`float`, *optional*, defaults to 1.0):
|
||||
Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
|
||||
d_model (`int`, *optional*, defaults to 256):
|
||||
Dimension of the layers exclude hybrid encoder.
|
||||
num_queries (`int`, *optional*, defaults to 300):
|
||||
Number of object queries.
|
||||
decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
|
||||
Multi level features dimension for decoder
|
||||
decoder_ffn_dim (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
num_feature_levels (`int`, *optional*, defaults to 3):
|
||||
The number of input feature levels.
|
||||
decoder_n_points (`int`, *optional*, defaults to 4):
|
||||
The number of sampled keys in each feature level for each attention head in the decoder.
|
||||
decoder_layers (`int`, *optional*, defaults to 6):
|
||||
Number of decoder layers.
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
|
||||
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
num_denoising (`int`, *optional*, defaults to 100):
|
||||
The total number of denoising tasks or queries to be used for contrastive denoising.
|
||||
label_noise_ratio (`float`, *optional*, defaults to 0.5):
|
||||
The fraction of denoising labels to which random noise should be added.
|
||||
box_noise_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale or magnitude of noise to be added to the bounding boxes.
|
||||
learn_initial_query (`bool`, *optional*, defaults to `False`):
|
||||
Indicates whether the initial query embeddings for the decoder should be learned during training
|
||||
anchor_image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`):
|
||||
Height and width of the input image used during evaluation to generate the bounding box anchors.
|
||||
disable_custom_kernels (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable custom kernels.
|
||||
with_box_refine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
|
||||
based on the predictions from the previous layer.
|
||||
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
||||
Whether the architecture has an encoder decoder structure.
|
||||
matcher_alpha (`float`, *optional*, defaults to 0.25):
|
||||
Parameter alpha used by the Hungarian Matcher.
|
||||
matcher_gamma (`float`, *optional*, defaults to 2.0):
|
||||
Parameter gamma used by the Hungarian Matcher.
|
||||
matcher_class_cost (`float`, *optional*, defaults to 2.0):
|
||||
The relative weight of the class loss used by the Hungarian Matcher.
|
||||
matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
|
||||
The relative weight of the bounding box loss used by the Hungarian Matcher.
|
||||
matcher_giou_cost (`float`, *optional*, defaults to 2.0):
|
||||
The relative weight of the giou loss of used by the Hungarian Matcher.
|
||||
use_focal_loss (`bool`, *optional*, defaults to `True`):
|
||||
Parameter informing if focal focal should be used.
|
||||
auxiliary_loss (`bool`, *optional*, defaults to `True`):
|
||||
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
focal_loss_alpha (`float`, *optional*, defaults to 0.75):
|
||||
Parameter alpha used to compute the focal loss.
|
||||
focal_loss_gamma (`float`, *optional*, defaults to 2.0):
|
||||
Parameter gamma used to compute the focal loss.
|
||||
weight_loss_vfl (`float`, *optional*, defaults to 1.0):
|
||||
Relative weight of the varifocal loss in the object detection loss.
|
||||
weight_loss_bbox (`float`, *optional*, defaults to 5.0):
|
||||
Relative weight of the L1 bounding box loss in the object detection loss.
|
||||
weight_loss_giou (`float`, *optional*, defaults to 2.0):
|
||||
Relative weight of the generalized IoU loss in the object detection loss.
|
||||
eos_coefficient (`float`, *optional*, defaults to 0.0001):
|
||||
Relative classification weight of the 'no-object' class in the object detection loss.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import RTDetrConfig, RTDetrModel
|
||||
|
||||
>>> # Initializing a RT-DETR configuration
|
||||
>>> configuration = RTDetrConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the configuration
|
||||
>>> model = RTDetrModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "rt_detr"
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initializer_range=0.01,
|
||||
layer_norm_eps=1e-5,
|
||||
batch_norm_eps=1e-5,
|
||||
# backbone
|
||||
backbone_config=None,
|
||||
backbone=None,
|
||||
use_pretrained_backbone=False,
|
||||
use_timm_backbone=False,
|
||||
backbone_kwargs=None,
|
||||
# encoder HybridEncoder
|
||||
encoder_hidden_dim=256,
|
||||
encoder_in_channels=[512, 1024, 2048],
|
||||
feat_strides=[8, 16, 32],
|
||||
encoder_layers=1,
|
||||
encoder_ffn_dim=1024,
|
||||
encoder_attention_heads=8,
|
||||
dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
encode_proj_layers=[2],
|
||||
positional_encoding_temperature=10000,
|
||||
encoder_activation_function="gelu",
|
||||
activation_function="silu",
|
||||
eval_size=None,
|
||||
normalize_before=False,
|
||||
hidden_expansion=1.0,
|
||||
# decoder RTDetrTransformer
|
||||
d_model=256,
|
||||
num_queries=300,
|
||||
decoder_in_channels=[256, 256, 256],
|
||||
decoder_ffn_dim=1024,
|
||||
num_feature_levels=3,
|
||||
decoder_n_points=4,
|
||||
decoder_layers=6,
|
||||
decoder_attention_heads=8,
|
||||
decoder_activation_function="relu",
|
||||
attention_dropout=0.0,
|
||||
num_denoising=100,
|
||||
label_noise_ratio=0.5,
|
||||
box_noise_scale=1.0,
|
||||
learn_initial_query=False,
|
||||
anchor_image_size=[640, 640],
|
||||
disable_custom_kernels=True,
|
||||
with_box_refine=True,
|
||||
is_encoder_decoder=True,
|
||||
# Loss
|
||||
matcher_alpha=0.25,
|
||||
matcher_gamma=2.0,
|
||||
matcher_class_cost=2.0,
|
||||
matcher_bbox_cost=5.0,
|
||||
matcher_giou_cost=2.0,
|
||||
use_focal_loss=True,
|
||||
auxiliary_loss=True,
|
||||
focal_loss_alpha=0.75,
|
||||
focal_loss_gamma=2.0,
|
||||
weight_loss_vfl=1.0,
|
||||
weight_loss_bbox=5.0,
|
||||
weight_loss_giou=2.0,
|
||||
eos_coefficient=1e-4,
|
||||
**kwargs,
|
||||
):
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.batch_norm_eps = batch_norm_eps
|
||||
# backbone
|
||||
if backbone_config is None and backbone is None:
|
||||
logger.info(
|
||||
"`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
|
||||
)
|
||||
backbone_config = RTDetrResNetConfig(
|
||||
num_channels=3,
|
||||
embedding_size=64,
|
||||
hidden_sizes=[256, 512, 1024, 2048],
|
||||
depths=[3, 4, 6, 3],
|
||||
layer_type="bottleneck",
|
||||
hidden_act="relu",
|
||||
downsample_in_first_stage=False,
|
||||
downsample_in_bottleneck=False,
|
||||
out_features=None,
|
||||
out_indices=[2, 3, 4],
|
||||
)
|
||||
elif isinstance(backbone_config, dict):
|
||||
backbone_model_type = backbone_config.pop("model_type")
|
||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||
backbone_config = config_class.from_dict(backbone_config)
|
||||
|
||||
verify_backbone_config_arguments(
|
||||
use_timm_backbone=use_timm_backbone,
|
||||
use_pretrained_backbone=use_pretrained_backbone,
|
||||
backbone=backbone,
|
||||
backbone_config=backbone_config,
|
||||
backbone_kwargs=backbone_kwargs,
|
||||
)
|
||||
|
||||
self.backbone_config = backbone_config
|
||||
self.backbone = backbone
|
||||
self.use_pretrained_backbone = use_pretrained_backbone
|
||||
self.use_timm_backbone = use_timm_backbone
|
||||
self.backbone_kwargs = backbone_kwargs
|
||||
# encoder
|
||||
self.encoder_hidden_dim = encoder_hidden_dim
|
||||
self.encoder_in_channels = encoder_in_channels
|
||||
self.feat_strides = feat_strides
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.encode_proj_layers = encode_proj_layers
|
||||
self.encoder_layers = encoder_layers
|
||||
self.positional_encoding_temperature = positional_encoding_temperature
|
||||
self.eval_size = eval_size
|
||||
self.normalize_before = normalize_before
|
||||
self.encoder_activation_function = encoder_activation_function
|
||||
self.activation_function = activation_function
|
||||
self.hidden_expansion = hidden_expansion
|
||||
# decoder
|
||||
self.d_model = d_model
|
||||
self.num_queries = num_queries
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_in_channels = decoder_in_channels
|
||||
self.num_feature_levels = num_feature_levels
|
||||
self.decoder_n_points = decoder_n_points
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.decoder_activation_function = decoder_activation_function
|
||||
self.attention_dropout = attention_dropout
|
||||
self.num_denoising = num_denoising
|
||||
self.label_noise_ratio = label_noise_ratio
|
||||
self.box_noise_scale = box_noise_scale
|
||||
self.learn_initial_query = learn_initial_query
|
||||
self.anchor_image_size = anchor_image_size
|
||||
self.auxiliary_loss = auxiliary_loss
|
||||
self.disable_custom_kernels = disable_custom_kernels
|
||||
self.with_box_refine = with_box_refine
|
||||
# Loss
|
||||
self.matcher_alpha = matcher_alpha
|
||||
self.matcher_gamma = matcher_gamma
|
||||
self.matcher_class_cost = matcher_class_cost
|
||||
self.matcher_bbox_cost = matcher_bbox_cost
|
||||
self.matcher_giou_cost = matcher_giou_cost
|
||||
self.use_focal_loss = use_focal_loss
|
||||
self.focal_loss_alpha = focal_loss_alpha
|
||||
self.focal_loss_gamma = focal_loss_gamma
|
||||
self.weight_loss_vfl = weight_loss_vfl
|
||||
self.weight_loss_bbox = weight_loss_bbox
|
||||
self.weight_loss_giou = weight_loss_giou
|
||||
self.eos_coefficient = eos_coefficient
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
||||
|
||||
@classmethod
|
||||
def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
|
||||
"""Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
|
||||
configuration.
|
||||
|
||||
Args:
|
||||
backbone_config ([`PretrainedConfig`]):
|
||||
The backbone configuration.
|
||||
|
||||
Returns:
|
||||
[`RTDetrConfig`]: An instance of a configuration object
|
||||
"""
|
||||
return cls(
|
||||
backbone_config=backbone_config,
|
||||
**kwargs,
|
||||
)
|
111
src/transformers/models/rt_detr/configuration_rt_detr_resnet.py
Normal file
111
src/transformers/models/rt_detr/configuration_rt_detr_resnet.py
Normal file
@ -0,0 +1,111 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RT-DETR ResNet model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate an
|
||||
ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the ResNet
|
||||
[microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
embedding_size (`int`, *optional*, defaults to 64):
|
||||
Dimensionality (hidden size) for the embedding layer.
|
||||
hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
|
||||
Dimensionality (hidden size) at each stage.
|
||||
depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
|
||||
Depth (number of layers) for each stage.
|
||||
layer_type (`str`, *optional*, defaults to `"bottleneck"`):
|
||||
The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
|
||||
`"bottleneck"` (used for larger models like resnet-50 and above).
|
||||
hidden_act (`str`, *optional*, defaults to `"relu"`):
|
||||
The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
|
||||
are supported.
|
||||
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the first stage will downsample the inputs using a `stride` of 2.
|
||||
downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
|
||||
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
|
||||
same order as defined in the `stage_names` attribute.
|
||||
out_indices (`List[int]`, *optional*):
|
||||
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
|
||||
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
|
||||
If unset and `out_features` is unset, will default to the last stage. Must be in the
|
||||
same order as defined in the `stage_names` attribute.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone
|
||||
|
||||
>>> # Initializing a ResNet resnet-50 style configuration
|
||||
>>> configuration = RTDetrResNetConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the resnet-50 style configuration
|
||||
>>> model = RTDetrResnetBackbone(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "rt_detr_resnet"
|
||||
layer_types = ["basic", "bottleneck"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels=3,
|
||||
embedding_size=64,
|
||||
hidden_sizes=[256, 512, 1024, 2048],
|
||||
depths=[3, 4, 6, 3],
|
||||
layer_type="bottleneck",
|
||||
hidden_act="relu",
|
||||
downsample_in_first_stage=False,
|
||||
downsample_in_bottleneck=False,
|
||||
out_features=None,
|
||||
out_indices=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if layer_type not in self.layer_types:
|
||||
raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
|
||||
self.num_channels = num_channels
|
||||
self.embedding_size = embedding_size
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.depths = depths
|
||||
self.layer_type = layer_type
|
||||
self.hidden_act = hidden_act
|
||||
self.downsample_in_first_stage = downsample_in_first_stage
|
||||
self.downsample_in_bottleneck = downsample_in_bottleneck
|
||||
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
|
||||
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
|
||||
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
|
||||
)
|
@ -0,0 +1,782 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert RT Detr checkpoints with Timm backbone"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from transformers import RTDetrConfig, RTDetrForObjectDetection, RTDetrImageProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_rt_detr_config(model_name: str) -> RTDetrConfig:
|
||||
config = RTDetrConfig()
|
||||
|
||||
config.num_labels = 80
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "coco-detection-mmdet-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
if model_name == "rtdetr_r18vd":
|
||||
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
|
||||
config.backbone_config.depths = [2, 2, 2, 2]
|
||||
config.backbone_config.layer_type = "basic"
|
||||
config.encoder_in_channels = [128, 256, 512]
|
||||
config.hidden_expansion = 0.5
|
||||
config.decoder_layers = 3
|
||||
elif model_name == "rtdetr_r34vd":
|
||||
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
|
||||
config.backbone_config.depths = [3, 4, 6, 3]
|
||||
config.backbone_config.layer_type = "basic"
|
||||
config.encoder_in_channels = [128, 256, 512]
|
||||
config.hidden_expansion = 0.5
|
||||
config.decoder_layers = 4
|
||||
elif model_name == "rtdetr_r50vd_m":
|
||||
pass
|
||||
elif model_name == "rtdetr_r50vd":
|
||||
pass
|
||||
elif model_name == "rtdetr_r101vd":
|
||||
config.backbone_config.depths = [3, 4, 23, 3]
|
||||
config.encoder_ffn_dim = 2048
|
||||
config.encoder_hidden_dim = 384
|
||||
config.decoder_in_channels = [384, 384, 384]
|
||||
elif model_name == "rtdetr_r18vd_coco_o365":
|
||||
config.backbone_config.hidden_sizes = [64, 128, 256, 512]
|
||||
config.backbone_config.depths = [2, 2, 2, 2]
|
||||
config.backbone_config.layer_type = "basic"
|
||||
config.encoder_in_channels = [128, 256, 512]
|
||||
config.hidden_expansion = 0.5
|
||||
config.decoder_layers = 3
|
||||
elif model_name == "rtdetr_r50vd_coco_o365":
|
||||
pass
|
||||
elif model_name == "rtdetr_r101vd_coco_o365":
|
||||
config.backbone_config.depths = [3, 4, 23, 3]
|
||||
config.encoder_ffn_dim = 2048
|
||||
config.encoder_hidden_dim = 384
|
||||
config.decoder_in_channels = [384, 384, 384]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_rename_keys(config):
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
rename_keys = []
|
||||
|
||||
# stem
|
||||
# fmt: off
|
||||
last_key = ["weight", "bias", "running_mean", "running_var"]
|
||||
|
||||
for level in range(3):
|
||||
rename_keys.append((f"backbone.conv1.conv1_{level+1}.conv.weight", f"model.backbone.model.embedder.embedder.{level}.convolution.weight"))
|
||||
for last in last_key:
|
||||
rename_keys.append((f"backbone.conv1.conv1_{level+1}.norm.{last}", f"model.backbone.model.embedder.embedder.{level}.normalization.{last}"))
|
||||
|
||||
for stage_idx in range(len(config.backbone_config.depths)):
|
||||
for layer_idx in range(config.backbone_config.depths[stage_idx]):
|
||||
# shortcut
|
||||
if layer_idx == 0:
|
||||
if stage_idx == 0:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.weight",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.convolution.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.0.short.norm.{last}",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.normalization.{last}",
|
||||
)
|
||||
)
|
||||
else:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.conv.weight",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.convolution.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.0.short.conv.norm.{last}",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.0.shortcut.1.normalization.{last}",
|
||||
)
|
||||
)
|
||||
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.conv.weight",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.convolution.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append((
|
||||
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2a.norm.{last}",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.0.normalization.{last}",
|
||||
))
|
||||
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.conv.weight",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.convolution.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append((
|
||||
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2b.norm.{last}",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.1.normalization.{last}",
|
||||
))
|
||||
|
||||
# https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/nn/backbone/presnet.py#L171
|
||||
if config.backbone_config.layer_type != "basic":
|
||||
rename_keys.append(
|
||||
(
|
||||
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.conv.weight",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.convolution.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append((
|
||||
f"backbone.res_layers.{stage_idx}.blocks.{layer_idx}.branch2c.norm.{last}",
|
||||
f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.2.normalization.{last}",
|
||||
))
|
||||
# fmt: on
|
||||
|
||||
for i in range(config.encoder_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.self_attn.out_proj.weight",
|
||||
f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.self_attn.out_proj.bias",
|
||||
f"model.encoder.encoder.{i}.layers.0.self_attn.out_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.linear1.weight",
|
||||
f"model.encoder.encoder.{i}.layers.0.fc1.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.linear1.bias",
|
||||
f"model.encoder.encoder.{i}.layers.0.fc1.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.linear2.weight",
|
||||
f"model.encoder.encoder.{i}.layers.0.fc2.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.linear2.bias",
|
||||
f"model.encoder.encoder.{i}.layers.0.fc2.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.norm1.weight",
|
||||
f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.norm1.bias",
|
||||
f"model.encoder.encoder.{i}.layers.0.self_attn_layer_norm.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.norm2.weight",
|
||||
f"model.encoder.encoder.{i}.layers.0.final_layer_norm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.encoder.{i}.layers.0.norm2.bias",
|
||||
f"model.encoder.encoder.{i}.layers.0.final_layer_norm.bias",
|
||||
)
|
||||
)
|
||||
|
||||
for j in range(0, 3):
|
||||
rename_keys.append((f"encoder.input_proj.{j}.0.weight", f"model.encoder_input_proj.{j}.0.weight"))
|
||||
for last in last_key:
|
||||
rename_keys.append((f"encoder.input_proj.{j}.1.{last}", f"model.encoder_input_proj.{j}.1.{last}"))
|
||||
|
||||
block_levels = 3 if config.backbone_config.layer_type != "basic" else 4
|
||||
|
||||
for i in range(len(config.encoder_in_channels) - 1):
|
||||
# encoder layers: hybridencoder parts
|
||||
for j in range(1, block_levels):
|
||||
rename_keys.append(
|
||||
(f"encoder.fpn_blocks.{i}.conv{j}.conv.weight", f"model.encoder.fpn_blocks.{i}.conv{j}.conv.weight")
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.fpn_blocks.{i}.conv{j}.norm.{last}",
|
||||
f"model.encoder.fpn_blocks.{i}.conv{j}.norm.{last}",
|
||||
)
|
||||
)
|
||||
|
||||
rename_keys.append((f"encoder.lateral_convs.{i}.conv.weight", f"model.encoder.lateral_convs.{i}.conv.weight"))
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(f"encoder.lateral_convs.{i}.norm.{last}", f"model.encoder.lateral_convs.{i}.norm.{last}")
|
||||
)
|
||||
|
||||
for j in range(3):
|
||||
for k in range(1, 3):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
|
||||
f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
|
||||
f"model.encoder.fpn_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
|
||||
)
|
||||
)
|
||||
|
||||
for j in range(1, block_levels):
|
||||
rename_keys.append(
|
||||
(f"encoder.pan_blocks.{i}.conv{j}.conv.weight", f"model.encoder.pan_blocks.{i}.conv{j}.conv.weight")
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.pan_blocks.{i}.conv{j}.norm.{last}",
|
||||
f"model.encoder.pan_blocks.{i}.conv{j}.norm.{last}",
|
||||
)
|
||||
)
|
||||
|
||||
for j in range(3):
|
||||
for k in range(1, 3):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
|
||||
f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.conv.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
|
||||
f"model.encoder.pan_blocks.{i}.bottlenecks.{j}.conv{k}.norm.{last}",
|
||||
)
|
||||
)
|
||||
|
||||
rename_keys.append(
|
||||
(f"encoder.downsample_convs.{i}.conv.weight", f"model.encoder.downsample_convs.{i}.conv.weight")
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(f"encoder.downsample_convs.{i}.norm.{last}", f"model.encoder.downsample_convs.{i}.norm.{last}")
|
||||
)
|
||||
|
||||
for i in range(config.decoder_layers):
|
||||
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.self_attn.out_proj.weight",
|
||||
f"model.decoder.layers.{i}.self_attn.out_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.self_attn.out_proj.bias",
|
||||
f"model.decoder.layers.{i}.self_attn.out_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.weight",
|
||||
f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.sampling_offsets.bias",
|
||||
f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.attention_weights.weight",
|
||||
f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.attention_weights.bias",
|
||||
f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.value_proj.weight",
|
||||
f"model.decoder.layers.{i}.encoder_attn.value_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.value_proj.bias",
|
||||
f"model.decoder.layers.{i}.encoder_attn.value_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.output_proj.weight",
|
||||
f"model.decoder.layers.{i}.encoder_attn.output_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.decoder.layers.{i}.cross_attn.output_proj.bias",
|
||||
f"model.decoder.layers.{i}.encoder_attn.output_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"decoder.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"decoder.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"decoder.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"decoder.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
rename_keys.append((f"decoder.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"decoder.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"decoder.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"decoder.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"decoder.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"decoder.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias")
|
||||
)
|
||||
|
||||
for i in range(config.decoder_layers):
|
||||
# decoder + class and bounding box heads
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_score_head.{i}.weight",
|
||||
f"model.decoder.class_embed.{i}.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_score_head.{i}.bias",
|
||||
f"model.decoder.class_embed.{i}.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_bbox_head.{i}.layers.0.weight",
|
||||
f"model.decoder.bbox_embed.{i}.layers.0.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_bbox_head.{i}.layers.0.bias",
|
||||
f"model.decoder.bbox_embed.{i}.layers.0.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_bbox_head.{i}.layers.1.weight",
|
||||
f"model.decoder.bbox_embed.{i}.layers.1.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_bbox_head.{i}.layers.1.bias",
|
||||
f"model.decoder.bbox_embed.{i}.layers.1.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_bbox_head.{i}.layers.2.weight",
|
||||
f"model.decoder.bbox_embed.{i}.layers.2.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.dec_bbox_head.{i}.layers.2.bias",
|
||||
f"model.decoder.bbox_embed.{i}.layers.2.bias",
|
||||
)
|
||||
)
|
||||
|
||||
# decoder projection
|
||||
for i in range(len(config.decoder_in_channels)):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.input_proj.{i}.conv.weight",
|
||||
f"model.decoder_input_proj.{i}.0.weight",
|
||||
)
|
||||
)
|
||||
for last in last_key:
|
||||
rename_keys.append(
|
||||
(
|
||||
f"decoder.input_proj.{i}.norm.{last}",
|
||||
f"model.decoder_input_proj.{i}.1.{last}",
|
||||
)
|
||||
)
|
||||
|
||||
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decoder.denoising_class_embed.weight", "model.denoising_class_embed.weight"),
|
||||
("decoder.query_pos_head.layers.0.weight", "model.decoder.query_pos_head.layers.0.weight"),
|
||||
("decoder.query_pos_head.layers.0.bias", "model.decoder.query_pos_head.layers.0.bias"),
|
||||
("decoder.query_pos_head.layers.1.weight", "model.decoder.query_pos_head.layers.1.weight"),
|
||||
("decoder.query_pos_head.layers.1.bias", "model.decoder.query_pos_head.layers.1.bias"),
|
||||
("decoder.enc_output.0.weight", "model.enc_output.0.weight"),
|
||||
("decoder.enc_output.0.bias", "model.enc_output.0.bias"),
|
||||
("decoder.enc_output.1.weight", "model.enc_output.1.weight"),
|
||||
("decoder.enc_output.1.bias", "model.enc_output.1.bias"),
|
||||
("decoder.enc_score_head.weight", "model.enc_score_head.weight"),
|
||||
("decoder.enc_score_head.bias", "model.enc_score_head.bias"),
|
||||
("decoder.enc_bbox_head.layers.0.weight", "model.enc_bbox_head.layers.0.weight"),
|
||||
("decoder.enc_bbox_head.layers.0.bias", "model.enc_bbox_head.layers.0.bias"),
|
||||
("decoder.enc_bbox_head.layers.1.weight", "model.enc_bbox_head.layers.1.weight"),
|
||||
("decoder.enc_bbox_head.layers.1.bias", "model.enc_bbox_head.layers.1.bias"),
|
||||
("decoder.enc_bbox_head.layers.2.weight", "model.enc_bbox_head.layers.2.weight"),
|
||||
("decoder.enc_bbox_head.layers.2.bias", "model.enc_bbox_head.layers.2.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(state_dict, old, new):
|
||||
try:
|
||||
val = state_dict.pop(old)
|
||||
state_dict[new] = val
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def read_in_q_k_v(state_dict, config):
|
||||
prefix = ""
|
||||
encoder_hidden_dim = config.encoder_hidden_dim
|
||||
|
||||
# first: transformer encoder
|
||||
for i in range(config.encoder_layers):
|
||||
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"{prefix}encoder.encoder.{i}.layers.0.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[
|
||||
:encoder_hidden_dim, :
|
||||
]
|
||||
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim]
|
||||
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[
|
||||
encoder_hidden_dim : 2 * encoder_hidden_dim, :
|
||||
]
|
||||
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[
|
||||
encoder_hidden_dim : 2 * encoder_hidden_dim
|
||||
]
|
||||
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[
|
||||
-encoder_hidden_dim:, :
|
||||
]
|
||||
state_dict[f"model.encoder.encoder.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:]
|
||||
# next: transformer decoder (which is a bit more complex because it also includes cross-attention)
|
||||
for i in range(config.decoder_layers):
|
||||
# read in weights + bias of input projection layer of self-attention
|
||||
in_proj_weight = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"{prefix}decoder.decoder.layers.{i}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
|
||||
state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
|
||||
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
|
||||
state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
|
||||
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
|
||||
state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our RTDETR structure.
|
||||
"""
|
||||
|
||||
# load default config
|
||||
config = get_rt_detr_config(model_name)
|
||||
|
||||
# load original model from torch hub
|
||||
model_name_to_checkpoint_url = {
|
||||
"rtdetr_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth",
|
||||
"rtdetr_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth",
|
||||
"rtdetr_r50vd_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth",
|
||||
"rtdetr_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth",
|
||||
"rtdetr_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth",
|
||||
"rtdetr_r18vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_5x_coco_objects365_from_paddle.pth",
|
||||
"rtdetr_r50vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth",
|
||||
"rtdetr_r101vd_coco_o365": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_2x_coco_objects365_from_paddle.pth",
|
||||
}
|
||||
logger.info(f"Converting model {model_name}...")
|
||||
state_dict = torch.hub.load_state_dict_from_url(model_name_to_checkpoint_url[model_name], map_location="cpu")[
|
||||
"ema"
|
||||
]["module"]
|
||||
|
||||
# rename keys
|
||||
for src, dest in create_rename_keys(config):
|
||||
rename_key(state_dict, src, dest)
|
||||
# query, key and value matrices need special treatment
|
||||
read_in_q_k_v(state_dict, config)
|
||||
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
||||
for key in state_dict.copy().keys():
|
||||
if key.endswith("num_batches_tracked"):
|
||||
del state_dict[key]
|
||||
# for two_stage
|
||||
if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key):
|
||||
state_dict[key.split("model.decoder.")[-1]] = state_dict[key]
|
||||
|
||||
# finally, create HuggingFace model and load state dict
|
||||
model = RTDetrForObjectDetection(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# load image processor
|
||||
image_processor = RTDetrImageProcessor()
|
||||
|
||||
# prepare image
|
||||
img = prepare_img()
|
||||
|
||||
# preprocess image
|
||||
transformations = transforms.Compose(
|
||||
[
|
||||
transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
original_pixel_values = transformations(img).unsqueeze(0) # insert batch dimension
|
||||
|
||||
encoding = image_processor(images=img, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
assert torch.allclose(original_pixel_values, pixel_values)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
pixel_values = pixel_values.to(device)
|
||||
|
||||
# Pass image by the model
|
||||
outputs = model(pixel_values)
|
||||
|
||||
if model_name == "rtdetr_r18vd":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.3364253, -6.465683, -3.6130402],
|
||||
[-4.083815, -6.4039373, -6.97881],
|
||||
[-4.192215, -7.3410473, -6.9027247],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.16868353, 0.19833282, 0.21182671],
|
||||
[0.25559652, 0.55121744, 0.47988364],
|
||||
[0.7698693, 0.4124569, 0.46036878],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r34vd":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.3727384, -4.7921476, -5.7299604],
|
||||
[-4.840536, -8.455345, -4.1745796],
|
||||
[-4.1277084, -5.2154565, -5.7852697],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.258278, 0.5497808, 0.4732004],
|
||||
[0.16889669, 0.19890057, 0.21138911],
|
||||
[0.76632994, 0.4147879, 0.46851268],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r50vd_m":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.319764, -6.1349025, -6.094794],
|
||||
[-5.1056995, -7.744766, -4.803956],
|
||||
[-4.7685347, -7.9278393, -4.5751696],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.2582739, 0.55071366, 0.47660282],
|
||||
[0.16811174, 0.19954777, 0.21292639],
|
||||
[0.54986024, 0.2752091, 0.0561416],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r50vd":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.6476398, -5.001154, -4.9785104],
|
||||
[-4.1593494, -4.7038546, -5.946485],
|
||||
[-4.4374595, -4.658361, -6.2352347],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.16880608, 0.19992264, 0.21225442],
|
||||
[0.76837635, 0.4122631, 0.46368608],
|
||||
[0.2595386, 0.5483334, 0.4777486],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r101vd":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.6162, -4.9189, -4.6656],
|
||||
[-4.4701, -4.4997, -4.9659],
|
||||
[-5.6641, -7.9000, -5.0725],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.7707, 0.4124, 0.4585],
|
||||
[0.2589, 0.5492, 0.4735],
|
||||
[0.1688, 0.1993, 0.2108],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r18vd_coco_o365":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.8726, -5.9066, -5.2450],
|
||||
[-4.8157, -6.8764, -5.1656],
|
||||
[-4.7492, -5.7006, -5.1333],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.2552, 0.5501, 0.4773],
|
||||
[0.1685, 0.1986, 0.2104],
|
||||
[0.7692, 0.4141, 0.4620],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r50vd_coco_o365":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.6491, -3.9252, -5.3163],
|
||||
[-4.1386, -5.0348, -3.9016],
|
||||
[-4.4778, -4.5423, -5.7356],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.2583, 0.5492, 0.4747],
|
||||
[0.5501, 0.2754, 0.0574],
|
||||
[0.7693, 0.4137, 0.4613],
|
||||
]
|
||||
)
|
||||
elif model_name == "rtdetr_r101vd_coco_o365":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[
|
||||
[-4.5152, -5.6811, -5.7311],
|
||||
[-4.5358, -7.2422, -5.0941],
|
||||
[-4.6919, -5.5834, -6.0145],
|
||||
]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.7703, 0.4140, 0.4583],
|
||||
[0.1686, 0.1991, 0.2107],
|
||||
[0.2570, 0.5496, 0.4750],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown rt_detr_name: {model_name}")
|
||||
|
||||
assert torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits.to(outputs.logits.device), atol=1e-4)
|
||||
assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes.to(outputs.pred_boxes.device), atol=1e-3)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
# Upload model, image processor and config to the hub
|
||||
logger.info("Uploading PyTorch model and image processor to the hub...")
|
||||
config.push_to_hub(
|
||||
repo_id=repo_id, commit_message="Add config from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py"
|
||||
)
|
||||
model.push_to_hub(
|
||||
repo_id=repo_id, commit_message="Add model from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py"
|
||||
)
|
||||
image_processor.push_to_hub(
|
||||
repo_id=repo_id,
|
||||
commit_message="Add image processor from convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="rtdetr_r50vd",
|
||||
type=str,
|
||||
help="model_name of the checkpoint you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
help="repo_id where the model will be pushed to.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_rt_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id)
|
1120
src/transformers/models/rt_detr/image_processing_rt_detr.py
Normal file
1120
src/transformers/models/rt_detr/image_processing_rt_detr.py
Normal file
File diff suppressed because it is too large
Load Diff
2675
src/transformers/models/rt_detr/modeling_rt_detr.py
Normal file
2675
src/transformers/models/rt_detr/modeling_rt_detr.py
Normal file
File diff suppressed because it is too large
Load Diff
426
src/transformers/models/rt_detr/modeling_rt_detr_resnet.py
Normal file
426
src/transformers/models/rt_detr/modeling_rt_detr_resnet.py
Normal file
@ -0,0 +1,426 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
PyTorch RTDetr specific ResNet model. The main difference between hugginface ResNet model is that this RTDetrResNet model forces to use shortcut at the first layer in the resnet-18/34 models.
|
||||
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import (
|
||||
BackboneOutput,
|
||||
BaseModelOutputWithNoAttention,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_rt_detr_resnet import RTDetrResNetConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "RTDetrResNetConfig"
|
||||
|
||||
# Base docstring
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/resnet-50"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]
|
||||
|
||||
|
||||
# Copied from transformers.models.resnet.modeling_resnet.ResNetConvLayer -> RTDetrResNetConvLayer
|
||||
class RTDetrResNetConvLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
|
||||
):
|
||||
super().__init__()
|
||||
self.convolution = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
|
||||
)
|
||||
self.normalization = nn.BatchNorm2d(out_channels)
|
||||
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
hidden_state = self.convolution(input)
|
||||
hidden_state = self.normalization(hidden_state)
|
||||
hidden_state = self.activation(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class RTDetrResNetEmbeddings(nn.Module):
|
||||
"""
|
||||
ResNet Embeddings (stem) composed of a deep aggressive convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, config: RTDetrResNetConfig):
|
||||
super().__init__()
|
||||
self.embedder = nn.Sequential(
|
||||
*[
|
||||
RTDetrResNetConvLayer(
|
||||
config.num_channels,
|
||||
config.embedding_size // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
activation=config.hidden_act,
|
||||
),
|
||||
RTDetrResNetConvLayer(
|
||||
config.embedding_size // 2,
|
||||
config.embedding_size // 2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
activation=config.hidden_act,
|
||||
),
|
||||
RTDetrResNetConvLayer(
|
||||
config.embedding_size // 2,
|
||||
config.embedding_size,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
activation=config.hidden_act,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.num_channels = config.num_channels
|
||||
|
||||
def forward(self, pixel_values: Tensor) -> Tensor:
|
||||
num_channels = pixel_values.shape[1]
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
embedding = self.embedder(pixel_values)
|
||||
embedding = self.pooler(embedding)
|
||||
return embedding
|
||||
|
||||
|
||||
# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut -> RTDetrResNetChortCut
|
||||
class RTDetrResNetShortCut(nn.Module):
|
||||
"""
|
||||
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
|
||||
downsample the input using `stride=2`.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
|
||||
super().__init__()
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
|
||||
self.normalization = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
hidden_state = self.convolution(input)
|
||||
hidden_state = self.normalization(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class RTDetrResNetBasicLayer(nn.Module):
|
||||
"""
|
||||
A classic ResNet's residual layer composed by two `3x3` convolutions.
|
||||
See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L34.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RTDetrResNetConfig,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int = 1,
|
||||
should_apply_shortcut: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = (
|
||||
nn.Sequential(
|
||||
*[nn.AvgPool2d(2, 2, 0, ceil_mode=True), RTDetrResNetShortCut(in_channels, out_channels, stride=1)]
|
||||
)
|
||||
if should_apply_shortcut
|
||||
else nn.Identity()
|
||||
)
|
||||
else:
|
||||
self.shortcut = (
|
||||
RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
|
||||
if should_apply_shortcut
|
||||
else nn.Identity()
|
||||
)
|
||||
self.layer = nn.Sequential(
|
||||
RTDetrResNetConvLayer(in_channels, out_channels, stride=stride),
|
||||
RTDetrResNetConvLayer(out_channels, out_channels, activation=None),
|
||||
)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_state):
|
||||
residual = hidden_state
|
||||
hidden_state = self.layer(hidden_state)
|
||||
residual = self.shortcut(residual)
|
||||
hidden_state += residual
|
||||
hidden_state = self.activation(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class RTDetrResNetBottleNeckLayer(nn.Module):
|
||||
"""
|
||||
A classic RTDetrResNet's bottleneck layer composed by three `3x3` convolutions.
|
||||
|
||||
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
|
||||
convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
|
||||
`downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RTDetrResNetConfig,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
reduction = 4
|
||||
should_apply_shortcut = in_channels != out_channels or stride != 1
|
||||
reduces_channels = out_channels // reduction
|
||||
if stride == 2:
|
||||
self.shortcut = nn.Sequential(
|
||||
*[
|
||||
nn.AvgPool2d(2, 2, 0, ceil_mode=True),
|
||||
RTDetrResNetShortCut(in_channels, out_channels, stride=1)
|
||||
if should_apply_shortcut
|
||||
else nn.Identity(),
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.shortcut = (
|
||||
RTDetrResNetShortCut(in_channels, out_channels, stride=stride)
|
||||
if should_apply_shortcut
|
||||
else nn.Identity()
|
||||
)
|
||||
self.layer = nn.Sequential(
|
||||
RTDetrResNetConvLayer(
|
||||
in_channels, reduces_channels, kernel_size=1, stride=stride if config.downsample_in_bottleneck else 1
|
||||
),
|
||||
RTDetrResNetConvLayer(
|
||||
reduces_channels, reduces_channels, stride=stride if not config.downsample_in_bottleneck else 1
|
||||
),
|
||||
RTDetrResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
|
||||
)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_state):
|
||||
residual = hidden_state
|
||||
hidden_state = self.layer(hidden_state)
|
||||
residual = self.shortcut(residual)
|
||||
hidden_state += residual
|
||||
hidden_state = self.activation(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class RTDetrResNetStage(nn.Module):
|
||||
"""
|
||||
A RTDetrResNet stage composed by stacked layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RTDetrResNetConfig,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: int = 2,
|
||||
depth: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
layer = RTDetrResNetBottleNeckLayer if config.layer_type == "bottleneck" else RTDetrResNetBasicLayer
|
||||
|
||||
if config.layer_type == "bottleneck":
|
||||
first_layer = layer(
|
||||
config,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride=stride,
|
||||
)
|
||||
else:
|
||||
first_layer = layer(config, in_channels, out_channels, stride=stride, should_apply_shortcut=True)
|
||||
self.layers = nn.Sequential(
|
||||
first_layer, *[layer(config, out_channels, out_channels) for _ in range(depth - 1)]
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
hidden_state = input
|
||||
for layer in self.layers:
|
||||
hidden_state = layer(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
# Copied from transformers.models.resnet.modeling_resnet.ResNetEncoder with ResNet->RTDetrResNet
|
||||
class RTDetrResNetEncoder(nn.Module):
|
||||
def __init__(self, config: RTDetrResNetConfig):
|
||||
super().__init__()
|
||||
self.stages = nn.ModuleList([])
|
||||
# based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
|
||||
self.stages.append(
|
||||
RTDetrResNetStage(
|
||||
config,
|
||||
config.embedding_size,
|
||||
config.hidden_sizes[0],
|
||||
stride=2 if config.downsample_in_first_stage else 1,
|
||||
depth=config.depths[0],
|
||||
)
|
||||
)
|
||||
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
|
||||
for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
|
||||
self.stages.append(RTDetrResNetStage(config, in_channels, out_channels, depth=depth))
|
||||
|
||||
def forward(
|
||||
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
|
||||
) -> BaseModelOutputWithNoAttention:
|
||||
hidden_states = () if output_hidden_states else None
|
||||
|
||||
for stage_module in self.stages:
|
||||
if output_hidden_states:
|
||||
hidden_states = hidden_states + (hidden_state,)
|
||||
|
||||
hidden_state = stage_module(hidden_state)
|
||||
|
||||
if output_hidden_states:
|
||||
hidden_states = hidden_states + (hidden_state,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
|
||||
|
||||
return BaseModelOutputWithNoAttention(
|
||||
last_hidden_state=hidden_state,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RTDetrResNet
|
||||
class RTDetrResNetPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = RTDetrResNetConfig
|
||||
base_model_prefix = "resnet"
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(module.weight, 1)
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
|
||||
RTDETR_RESNET_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`RTDetrResNetConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
RTDETR_RESNET_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
||||
[`RTDetrImageProcessor.__call__`] for details.
|
||||
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
ResNet backbone, to be used with frameworks like RTDETR.
|
||||
""",
|
||||
RTDETR_RESNET_START_DOCSTRING,
|
||||
)
|
||||
class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
super()._init_backbone(config)
|
||||
|
||||
self.num_features = [config.embedding_size] + config.hidden_sizes
|
||||
self.embedder = RTDetrResNetEmbeddings(config)
|
||||
self.encoder = RTDetrResNetEncoder(config)
|
||||
|
||||
# initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(RTDETR_RESNET_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
|
||||
) -> BackboneOutput:
|
||||
"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone
|
||||
>>> import torch
|
||||
|
||||
>>> config = RTDetrResNetConfig()
|
||||
>>> model = RTDetrResNetBackbone(config)
|
||||
|
||||
>>> pixel_values = torch.randn(1, 3, 224, 224)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(pixel_values)
|
||||
|
||||
>>> feature_maps = outputs.feature_maps
|
||||
>>> list(feature_maps[-1].shape)
|
||||
[1, 2048, 7, 7]
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
embedding_output = self.embedder(pixel_values)
|
||||
|
||||
outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
feature_maps = ()
|
||||
for idx, stage in enumerate(self.stage_names):
|
||||
if stage in self.out_features:
|
||||
feature_maps += (hidden_states[idx],)
|
||||
|
||||
if not return_dict:
|
||||
output = (feature_maps,)
|
||||
if output_hidden_states:
|
||||
output += (outputs.hidden_states,)
|
||||
return output
|
||||
|
||||
return BackboneOutput(
|
||||
feature_maps=feature_maps,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=None,
|
||||
)
|
@ -113,10 +113,10 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
|
||||
return super()._from_config(config, **kwargs)
|
||||
|
||||
def freeze_batch_norm_2d(self):
|
||||
timm.layers.freeze_batch_norm_2d(self._backbone)
|
||||
timm.utils.model.freeze_batch_norm_2d(self._backbone)
|
||||
|
||||
def unfreeze_batch_norm_2d(self):
|
||||
timm.layers.unfreeze_batch_norm_2d(self._backbone)
|
||||
timm.utils.model.unfreeze_batch_norm_2d(self._backbone)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""
|
||||
|
@ -313,7 +313,6 @@ def load_backbone(config):
|
||||
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
|
||||
backbone_checkpoint = getattr(config, "backbone", None)
|
||||
backbone_kwargs = getattr(config, "backbone_kwargs", None)
|
||||
|
||||
backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs
|
||||
|
||||
if backbone_kwargs and backbone_config is not None:
|
||||
|
@ -7492,6 +7492,41 @@ def load_tf_weights_in_roformer(*args, **kwargs):
|
||||
requires_backends(load_tf_weights_in_roformer, ["torch"])
|
||||
|
||||
|
||||
class RTDetrForObjectDetection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RTDetrModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RTDetrPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RTDetrResNetBackbone(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RTDetrResNetPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RwkvForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -492,6 +492,13 @@ class PvtImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class RTDetrImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class SamImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
0
tests/models/rt_detr/__init__.py
Normal file
0
tests/models/rt_detr/__init__.py
Normal file
364
tests/models/rt_detr/test_image_processing_rt_detr.py
Normal file
364
tests/models/rt_detr/test_image_processing_rt_detr.py
Normal file
@ -0,0 +1,364 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import RTDetrImageProcessor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class RTDetrImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=4,
|
||||
num_channels=3,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=False,
|
||||
do_pad=False,
|
||||
return_tensors="pt",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.do_resize = do_resize
|
||||
self.size = size if size is not None else {"height": 640, "width": 640}
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.do_pad = do_pad
|
||||
self.return_tensors = return_tensors
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_rescale": self.do_rescale,
|
||||
"rescale_factor": self.rescale_factor,
|
||||
"do_normalize": self.do_normalize,
|
||||
"do_pad": self.do_pad,
|
||||
"return_tensors": self.return_tensors,
|
||||
}
|
||||
|
||||
def get_expected_values(self):
|
||||
return self.size["height"], self.size["width"]
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
height, width = self.get_expected_values()
|
||||
return self.num_channels, height, width
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = RTDetrImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = RTDetrImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "resample"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "return_tensors"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 640, "width": 640})
|
||||
|
||||
def test_valid_coco_detection_annotations(self):
|
||||
# prepare image and target
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
|
||||
target = json.loads(f.read())
|
||||
|
||||
params = {"image_id": 39769, "annotations": target}
|
||||
|
||||
# encode them
|
||||
image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
|
||||
|
||||
# legal encodings (single image)
|
||||
_ = image_processing(images=image, annotations=params, return_tensors="pt")
|
||||
_ = image_processing(images=image, annotations=[params], return_tensors="pt")
|
||||
|
||||
# legal encodings (batch of one image)
|
||||
_ = image_processing(images=[image], annotations=params, return_tensors="pt")
|
||||
_ = image_processing(images=[image], annotations=[params], return_tensors="pt")
|
||||
|
||||
# legal encoding (batch of more than one image)
|
||||
n = 5
|
||||
_ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt")
|
||||
|
||||
# example of an illegal encoding (missing the 'image_id' key)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
image_processing(images=image, annotations={"annotations": target}, return_tensors="pt")
|
||||
|
||||
self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations"))
|
||||
|
||||
# example of an illegal encoding (unequal lengths of images and annotations)
|
||||
with self.assertRaises(ValueError) as e:
|
||||
image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt")
|
||||
|
||||
self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.")
|
||||
|
||||
@slow
|
||||
def test_call_pytorch_with_coco_detection_annotations(self):
|
||||
# prepare image and target
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
|
||||
target = json.loads(f.read())
|
||||
|
||||
target = {"image_id": 39769, "annotations": target}
|
||||
|
||||
# encode them
|
||||
image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
|
||||
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
|
||||
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 640, 640])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([0.5490, 0.5647, 0.5725])
|
||||
self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
# verify area
|
||||
expected_area = torch.tensor([2827.9883, 5403.4761, 235036.7344, 402070.2188, 71068.8281, 79601.2812])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["area"], expected_area))
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3))
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["image_id"], expected_image_id))
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd))
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels))
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size))
|
||||
# verify size
|
||||
expected_size = torch.tensor([640, 640])
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["size"], expected_size))
|
||||
|
||||
@slow
|
||||
def test_image_processor_outputs(self):
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
encoding = image_processing(images=image, return_tensors="pt")
|
||||
|
||||
# verify pixel values: shape
|
||||
expected_shape = torch.Size([1, 3, 640, 640])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
# verify pixel values: output values
|
||||
expected_slice = torch.tensor([0.5490196347236633, 0.5647059082984924, 0.572549045085907])
|
||||
self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-5))
|
||||
|
||||
def test_multiple_images_processor_outputs(self):
|
||||
images_urls = [
|
||||
"http://images.cocodataset.org/val2017/000000000139.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000285.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000632.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000724.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000776.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000785.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000802.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000872.jpg",
|
||||
]
|
||||
|
||||
images = []
|
||||
for url in images_urls:
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
images.append(image)
|
||||
|
||||
# apply image processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
encoding = image_processing(images=images, return_tensors="pt")
|
||||
|
||||
# verify if pixel_values is part of the encoding
|
||||
self.assertIn("pixel_values", encoding)
|
||||
|
||||
# verify pixel values: shape
|
||||
expected_shape = torch.Size([8, 3, 640, 640])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
# verify pixel values: output values
|
||||
expected_slices = torch.tensor(
|
||||
[
|
||||
[0.5333333611488342, 0.5568627715110779, 0.5647059082984924],
|
||||
[0.5372549295425415, 0.4705882668495178, 0.4274510145187378],
|
||||
[0.3960784673690796, 0.35686275362968445, 0.3686274588108063],
|
||||
[0.20784315466880798, 0.1882353127002716, 0.15294118225574493],
|
||||
[0.364705890417099, 0.364705890417099, 0.3686274588108063],
|
||||
[0.8078432083129883, 0.8078432083129883, 0.8078432083129883],
|
||||
[0.4431372880935669, 0.4431372880935669, 0.4431372880935669],
|
||||
[0.19607844948768616, 0.21176472306251526, 0.3607843220233917],
|
||||
]
|
||||
)
|
||||
self.assertTrue(torch.allclose(encoding["pixel_values"][:, 1, 0, :3], expected_slices, atol=1e-5))
|
||||
|
||||
@slow
|
||||
def test_batched_coco_detection_annotations(self):
|
||||
image_0 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
image_1 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png").resize((800, 800))
|
||||
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
|
||||
target = json.loads(f.read())
|
||||
|
||||
annotations_0 = {"image_id": 39769, "annotations": target}
|
||||
annotations_1 = {"image_id": 39769, "annotations": target}
|
||||
|
||||
# Adjust the bounding boxes for the resized image
|
||||
w_0, h_0 = image_0.size
|
||||
w_1, h_1 = image_1.size
|
||||
for i in range(len(annotations_1["annotations"])):
|
||||
coords = annotations_1["annotations"][i]["bbox"]
|
||||
new_bbox = [
|
||||
coords[0] * w_1 / w_0,
|
||||
coords[1] * h_1 / h_0,
|
||||
coords[2] * w_1 / w_0,
|
||||
coords[3] * h_1 / h_0,
|
||||
]
|
||||
annotations_1["annotations"][i]["bbox"] = new_bbox
|
||||
|
||||
images = [image_0, image_1]
|
||||
annotations = [annotations_0, annotations_1]
|
||||
|
||||
image_processing = RTDetrImageProcessor()
|
||||
encoding = image_processing(
|
||||
images=images,
|
||||
annotations=annotations,
|
||||
return_segmentation_masks=True,
|
||||
return_tensors="pt", # do_convert_annotations=True
|
||||
)
|
||||
|
||||
# Check the pixel values have been padded
|
||||
postprocessed_height, postprocessed_width = 640, 640
|
||||
expected_shape = torch.Size([2, 3, postprocessed_height, postprocessed_width])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
# Check the bounding boxes have been adjusted for padded images
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4]))
|
||||
self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4]))
|
||||
expected_boxes_0 = torch.tensor(
|
||||
[
|
||||
[0.6879, 0.4609, 0.0755, 0.3691],
|
||||
[0.2118, 0.3359, 0.2601, 0.1566],
|
||||
[0.5011, 0.5000, 0.9979, 1.0000],
|
||||
[0.5010, 0.5020, 0.9979, 0.9959],
|
||||
[0.3284, 0.5944, 0.5884, 0.8112],
|
||||
[0.8394, 0.5445, 0.3213, 0.9110],
|
||||
]
|
||||
)
|
||||
expected_boxes_1 = torch.tensor(
|
||||
[
|
||||
[0.5503, 0.2765, 0.0604, 0.2215],
|
||||
[0.1695, 0.2016, 0.2080, 0.0940],
|
||||
[0.5006, 0.4933, 0.9977, 0.9865],
|
||||
[0.5008, 0.5002, 0.9983, 0.9955],
|
||||
[0.2627, 0.5456, 0.4707, 0.8646],
|
||||
[0.7715, 0.4115, 0.4570, 0.7161],
|
||||
]
|
||||
)
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1e-3))
|
||||
self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1e-3))
|
||||
|
||||
# Check if do_convert_annotations=False, then the annotations are not converted to centre_x, centre_y, width, height
|
||||
# format and not in the range [0, 1]
|
||||
encoding = image_processing(
|
||||
images=images,
|
||||
annotations=annotations,
|
||||
return_segmentation_masks=True,
|
||||
do_convert_annotations=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4]))
|
||||
self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4]))
|
||||
# Convert to absolute coordinates
|
||||
unnormalized_boxes_0 = torch.vstack(
|
||||
[
|
||||
expected_boxes_0[:, 0] * postprocessed_width,
|
||||
expected_boxes_0[:, 1] * postprocessed_height,
|
||||
expected_boxes_0[:, 2] * postprocessed_width,
|
||||
expected_boxes_0[:, 3] * postprocessed_height,
|
||||
]
|
||||
).T
|
||||
unnormalized_boxes_1 = torch.vstack(
|
||||
[
|
||||
expected_boxes_1[:, 0] * postprocessed_width,
|
||||
expected_boxes_1[:, 1] * postprocessed_height,
|
||||
expected_boxes_1[:, 2] * postprocessed_width,
|
||||
expected_boxes_1[:, 3] * postprocessed_height,
|
||||
]
|
||||
).T
|
||||
# Convert from centre_x, centre_y, width, height to x_min, y_min, x_max, y_max
|
||||
expected_boxes_0 = torch.vstack(
|
||||
[
|
||||
unnormalized_boxes_0[:, 0] - unnormalized_boxes_0[:, 2] / 2,
|
||||
unnormalized_boxes_0[:, 1] - unnormalized_boxes_0[:, 3] / 2,
|
||||
unnormalized_boxes_0[:, 0] + unnormalized_boxes_0[:, 2] / 2,
|
||||
unnormalized_boxes_0[:, 1] + unnormalized_boxes_0[:, 3] / 2,
|
||||
]
|
||||
).T
|
||||
expected_boxes_1 = torch.vstack(
|
||||
[
|
||||
unnormalized_boxes_1[:, 0] - unnormalized_boxes_1[:, 2] / 2,
|
||||
unnormalized_boxes_1[:, 1] - unnormalized_boxes_1[:, 3] / 2,
|
||||
unnormalized_boxes_1[:, 0] + unnormalized_boxes_1[:, 2] / 2,
|
||||
unnormalized_boxes_1[:, 1] + unnormalized_boxes_1[:, 3] / 2,
|
||||
]
|
||||
).T
|
||||
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1))
|
||||
self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1))
|
680
tests/models/rt_detr/test_modeling_rt_detr.py
Normal file
680
tests/models/rt_detr/test_modeling_rt_detr.py
Normal file
@ -0,0 +1,680 @@
|
||||
# coding = utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch RT_DETR model."""
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
RTDetrConfig,
|
||||
RTDetrImageProcessor,
|
||||
RTDetrResNetConfig,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import require_torch, require_vision, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import RTDetrForObjectDetection, RTDetrModel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
CHECKPOINT = "PekingU/rtdetr_r50vd" # TODO: replace
|
||||
|
||||
|
||||
class RTDetrModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
n_targets=3,
|
||||
num_labels=10,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
batch_norm_eps=1e-5,
|
||||
# backbone
|
||||
backbone_config=None,
|
||||
# encoder HybridEncoder
|
||||
encoder_hidden_dim=32,
|
||||
encoder_in_channels=[128, 256, 512],
|
||||
feat_strides=[8, 16, 32],
|
||||
encoder_layers=1,
|
||||
encoder_ffn_dim=64,
|
||||
encoder_attention_heads=2,
|
||||
dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
encode_proj_layers=[2],
|
||||
positional_encoding_temperature=10000,
|
||||
encoder_activation_function="gelu",
|
||||
activation_function="silu",
|
||||
eval_size=None,
|
||||
normalize_before=False,
|
||||
# decoder RTDetrTransformer
|
||||
d_model=32,
|
||||
num_queries=30,
|
||||
decoder_in_channels=[32, 32, 32],
|
||||
decoder_ffn_dim=64,
|
||||
num_feature_levels=3,
|
||||
decoder_n_points=4,
|
||||
decoder_layers=2,
|
||||
decoder_attention_heads=2,
|
||||
decoder_activation_function="relu",
|
||||
attention_dropout=0.0,
|
||||
num_denoising=0,
|
||||
label_noise_ratio=0.5,
|
||||
box_noise_scale=1.0,
|
||||
learn_initial_query=False,
|
||||
anchor_image_size=[64, 64],
|
||||
image_size=64,
|
||||
disable_custom_kernels=True,
|
||||
with_box_refine=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = 3
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.n_targets = n_targets
|
||||
self.num_labels = num_labels
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.batch_norm_eps = batch_norm_eps
|
||||
self.backbone_config = backbone_config
|
||||
self.encoder_hidden_dim = encoder_hidden_dim
|
||||
self.encoder_in_channels = encoder_in_channels
|
||||
self.feat_strides = feat_strides
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.dropout = dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.encode_proj_layers = encode_proj_layers
|
||||
self.positional_encoding_temperature = positional_encoding_temperature
|
||||
self.encoder_activation_function = encoder_activation_function
|
||||
self.activation_function = activation_function
|
||||
self.eval_size = eval_size
|
||||
self.normalize_before = normalize_before
|
||||
self.d_model = d_model
|
||||
self.num_queries = num_queries
|
||||
self.decoder_in_channels = decoder_in_channels
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.num_feature_levels = num_feature_levels
|
||||
self.decoder_n_points = decoder_n_points
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.decoder_activation_function = decoder_activation_function
|
||||
self.attention_dropout = attention_dropout
|
||||
self.num_denoising = num_denoising
|
||||
self.label_noise_ratio = label_noise_ratio
|
||||
self.box_noise_scale = box_noise_scale
|
||||
self.learn_initial_query = learn_initial_query
|
||||
self.anchor_image_size = anchor_image_size
|
||||
self.image_size = image_size
|
||||
self.disable_custom_kernels = disable_custom_kernels
|
||||
self.with_box_refine = with_box_refine
|
||||
|
||||
self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
|
||||
labels = []
|
||||
for i in range(self.batch_size):
|
||||
target = {}
|
||||
target["class_labels"] = torch.randint(
|
||||
high=self.num_labels, size=(self.n_targets,), device=torch_device
|
||||
)
|
||||
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
|
||||
labels.append(target)
|
||||
|
||||
config = self.get_config()
|
||||
config.num_labels = self.num_labels
|
||||
return config, pixel_values, pixel_mask, labels
|
||||
|
||||
def get_config(self):
|
||||
hidden_sizes = [10, 20, 30, 40]
|
||||
backbone_config = RTDetrResNetConfig(
|
||||
embeddings_size=10,
|
||||
hidden_sizes=hidden_sizes,
|
||||
depths=[1, 1, 2, 1],
|
||||
out_features=["stage2", "stage3", "stage4"],
|
||||
out_indices=[2, 3, 4],
|
||||
)
|
||||
return RTDetrConfig.from_backbone_configs(
|
||||
backbone_config=backbone_config,
|
||||
encoder_hidden_dim=self.encoder_hidden_dim,
|
||||
encoder_in_channels=hidden_sizes[1:],
|
||||
feat_strides=self.feat_strides,
|
||||
encoder_layers=self.encoder_layers,
|
||||
encoder_ffn_dim=self.encoder_ffn_dim,
|
||||
encoder_attention_heads=self.encoder_attention_heads,
|
||||
dropout=self.dropout,
|
||||
activation_dropout=self.activation_dropout,
|
||||
encode_proj_layers=self.encode_proj_layers,
|
||||
positional_encoding_temperature=self.positional_encoding_temperature,
|
||||
encoder_activation_function=self.encoder_activation_function,
|
||||
activation_function=self.activation_function,
|
||||
eval_size=self.eval_size,
|
||||
normalize_before=self.normalize_before,
|
||||
d_model=self.d_model,
|
||||
num_queries=self.num_queries,
|
||||
decoder_in_channels=self.decoder_in_channels,
|
||||
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||
num_feature_levels=self.num_feature_levels,
|
||||
decoder_n_points=self.decoder_n_points,
|
||||
decoder_layers=self.decoder_layers,
|
||||
decoder_attention_heads=self.decoder_attention_heads,
|
||||
decoder_activation_function=self.decoder_activation_function,
|
||||
attention_dropout=self.attention_dropout,
|
||||
num_denoising=self.num_denoising,
|
||||
label_noise_ratio=self.label_noise_ratio,
|
||||
box_noise_scale=self.box_noise_scale,
|
||||
learn_initial_query=self.learn_initial_query,
|
||||
anchor_image_size=self.anchor_image_size,
|
||||
image_size=self.image_size,
|
||||
disable_custom_kernels=self.disable_custom_kernels,
|
||||
with_box_refine=self.with_box_refine,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_rt_detr_model(self, config, pixel_values, pixel_mask, labels):
|
||||
model = RTDetrModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model))
|
||||
|
||||
def create_and_check_rt_detr_object_detection_head_model(self, config, pixel_values, pixel_mask, labels):
|
||||
model = RTDetrForObjectDetection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
|
||||
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
|
||||
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
|
||||
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
|
||||
|
||||
|
||||
@require_torch
|
||||
class RTDetrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (RTDetrModel, RTDetrForObjectDetection) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"image-feature-extraction": RTDetrModel, "object-detection": RTDetrForObjectDetection}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
is_encoder_decoder = True
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_missing_keys = False
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
if model_class.__name__ == "RTDetrForObjectDetection":
|
||||
labels = []
|
||||
for i in range(self.model_tester.batch_size):
|
||||
target = {}
|
||||
target["class_labels"] = torch.ones(
|
||||
size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
|
||||
)
|
||||
target["boxes"] = torch.ones(
|
||||
self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
|
||||
)
|
||||
labels.append(target)
|
||||
inputs_dict["labels"] = labels
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RTDetrModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self,
|
||||
config_class=RTDetrConfig,
|
||||
has_text_modality=False,
|
||||
common_properties=["hidden_size", "num_attention_heads"],
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_rt_detr_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_rt_detr_model(*config_and_inputs)
|
||||
|
||||
def test_rt_detr_object_detection_head_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_rt_detr_object_detection_head_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="RTDetr does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RTDetr does not use test_inputs_embeds_matches_input_ids")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RTDetr does not support input and output embeddings")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RTDetr does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RTDetr does not use token embeddings")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Feed forward chunking is not implemented")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.encoder_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.encoder_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.encoder_attention_heads,
|
||||
self.model_tester.encoder_seq_length,
|
||||
self.model_tester.encoder_seq_length,
|
||||
],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 13
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
# Object Detection model returns pred_logits and pred_boxes
|
||||
if model_class.__name__ == "RTDetrForObjectDetection":
|
||||
correct_outlen += 2
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.decoder_attention_heads,
|
||||
self.model_tester.num_queries,
|
||||
self.model_tester.num_queries,
|
||||
],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.decoder_attention_heads,
|
||||
self.model_tester.num_feature_levels,
|
||||
self.model_tester.decoder_n_points,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
else:
|
||||
# RTDetr should maintin encoder_hidden_states output
|
||||
added_hidden_states = 2
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.encoder_layers)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.encoder_attention_heads,
|
||||
self.model_tester.encoder_seq_length,
|
||||
self.model_tester.encoder_seq_length,
|
||||
],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[1].shape[-2:]),
|
||||
[
|
||||
self.model_tester.image_size // self.model_tester.feat_strides[-1],
|
||||
self.model_tester.image_size // self.model_tester.feat_strides[-1],
|
||||
],
|
||||
)
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1
|
||||
)
|
||||
|
||||
self.assertIsInstance(hidden_states, (list, tuple))
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.num_queries, self.model_tester.d_model],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
# we take the first output since last_hidden_state is the first item
|
||||
output = outputs[0]
|
||||
|
||||
encoder_hidden_states = outputs.encoder_hidden_states[0]
|
||||
encoder_attentions = outputs.encoder_attentions[0]
|
||||
encoder_hidden_states.retain_grad()
|
||||
encoder_attentions.retain_grad()
|
||||
|
||||
decoder_attentions = outputs.decoder_attentions[0]
|
||||
decoder_attentions.retain_grad()
|
||||
|
||||
cross_attentions = outputs.cross_attentions[0]
|
||||
cross_attentions.retain_grad()
|
||||
|
||||
output.flatten()[0].backward(retain_graph=True)
|
||||
|
||||
self.assertIsNotNone(encoder_hidden_states.grad)
|
||||
self.assertIsNotNone(encoder_attentions.grad)
|
||||
self.assertIsNotNone(decoder_attentions.grad)
|
||||
self.assertIsNotNone(cross_attentions.grad)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_different_timm_backbone(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# let's pick a random timm backbone
|
||||
config.backbone = "tf_mobilenetv3_small_075"
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if model_class.__name__ == "RTDetrForObjectDetection":
|
||||
expected_shape = (
|
||||
self.model_tester.batch_size,
|
||||
self.model_tester.num_queries,
|
||||
self.model_tester.num_labels,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
def test_hf_backbone(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Load a pretrained HF checkpoint as backbone
|
||||
config.backbone = "microsoft/resnet-18"
|
||||
config.backbone_config = None
|
||||
config.use_timm_backbone = False
|
||||
config.use_pretrained_backbone = True
|
||||
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if model_class.__name__ == "RTDetrForObjectDetection":
|
||||
expected_shape = (
|
||||
self.model_tester.batch_size,
|
||||
self.model_tester.num_queries,
|
||||
self.model_tester.num_labels,
|
||||
)
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3)
|
||||
else:
|
||||
# Confirm out_indices was propogated to backbone
|
||||
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)
|
||||
|
||||
self.assertTrue(outputs)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
# Skip the check for the backbone
|
||||
for name, module in model.named_modules():
|
||||
if module.__class__.__name__ == "RTDetrConvEncoder":
|
||||
backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
|
||||
break
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if (
|
||||
"level_embed" in name
|
||||
or "sampling_offsets.bias" in name
|
||||
or "value_proj" in name
|
||||
or "output_proj" in name
|
||||
or "reference_points" in name
|
||||
or name in backbone_params
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class RTDetrModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_image_processor(self):
|
||||
return RTDetrImageProcessor.from_pretrained(CHECKPOINT) if is_vision_available() else None
|
||||
|
||||
def test_inference_object_detection_head(self):
|
||||
model = RTDetrForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device)
|
||||
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
expected_shape_logits = torch.Size((1, 300, model.config.num_labels))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape_logits)
|
||||
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[-4.64763879776001, -5.001153945922852, -4.978509902954102],
|
||||
[-4.159348487854004, -4.703853607177734, -5.946484565734863],
|
||||
[-4.437461853027344, -4.65836238861084, -6.235235691070557],
|
||||
]
|
||||
).to(torch_device)
|
||||
expected_boxes = torch.tensor(
|
||||
[
|
||||
[0.1688060760498047, 0.19992263615131378, 0.21225441992282867],
|
||||
[0.768376350402832, 0.41226309537887573, 0.4636859893798828],
|
||||
[0.25953856110572815, 0.5483334064483643, 0.4777486026287079],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_logits, atol=1e-4))
|
||||
|
||||
expected_shape_boxes = torch.Size((1, 300, 4))
|
||||
self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes)
|
||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=1e-4))
|
||||
|
||||
# verify postprocessing
|
||||
results = image_processor.post_process_object_detection(
|
||||
outputs, threshold=0.0, target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor(
|
||||
[0.9703017473220825, 0.9599503874778748, 0.9575679302215576, 0.9506784677505493], device=torch_device
|
||||
)
|
||||
expected_labels = [57, 15, 15, 65]
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[0.13774872, 0.37821293, 640.13074, 476.21088],
|
||||
[343.38132, 24.276838, 640.1404, 371.49573],
|
||||
[13.225126, 54.179348, 318.98422, 472.2207],
|
||||
[40.114475, 73.44104, 175.9573, 118.48469],
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(results["scores"][:4], expected_scores, atol=1e-4))
|
||||
self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(results["boxes"][:4], expected_slice_boxes, atol=1e-4))
|
130
tests/models/rt_detr/test_modeling_rt_detr_resnet.py
Normal file
130
tests/models/rt_detr/test_modeling_rt_detr_resnet.py
Normal file
@ -0,0 +1,130 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import RTDetrResNetConfig
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
from ...test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import RTDetrResNetBackbone
|
||||
|
||||
|
||||
class RTDetrResNetModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
image_size=32,
|
||||
num_channels=3,
|
||||
embeddings_size=10,
|
||||
hidden_sizes=[10, 20, 30, 40],
|
||||
depths=[1, 1, 2, 1],
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_act="relu",
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
out_features=["stage2", "stage3", "stage4"],
|
||||
out_indices=[2, 3, 4],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.num_channels = num_channels
|
||||
self.embeddings_size = embeddings_size
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.depths = depths
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_act = hidden_act
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
self.num_stages = len(hidden_sizes)
|
||||
self.out_features = out_features
|
||||
self.out_indices = out_indices
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return RTDetrResNetConfig(
|
||||
num_channels=self.num_channels,
|
||||
embeddings_size=self.embeddings_size,
|
||||
hidden_sizes=self.hidden_sizes,
|
||||
depths=self.depths,
|
||||
hidden_act=self.hidden_act,
|
||||
num_labels=self.num_labels,
|
||||
out_features=self.out_features,
|
||||
out_indices=self.out_indices,
|
||||
)
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels):
|
||||
model = RTDetrResNetBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = RTDetrResNetBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class RTDetrResNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (RTDetrResNetBackbone,) if is_torch_available() else ()
|
||||
has_attentions = False
|
||||
config_class = RTDetrResNetConfig
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RTDetrResNetModelTester(self)
|
Loading…
Reference in New Issue
Block a user