
RT-DETR
Overview
The RT-DETR model was proposed in DETRs Beat YOLOs on Real-time Object Detection by Wenyu Lv, Yian Zhao, Shangliang Xu, Jinman Wei, Guanzhong Wang, Cheng Cui, Yuning Du, Qingqing Dang, Yi Liu.
RT-DETR, short for "Real-Time DEtection TRansformer," is an object detection model designed to run in real time while maintaining high accuracy. It is built on the transformer architecture, which is now widely used across many areas of deep learning, and it processes an image to identify and locate multiple objects within it.
The abstract from the paper is the following:
Recently, end-to-end transformer-based detectors (DETRs) have achieved remarkable performance. However, the issue of the high computational cost of DETRs has not been effectively addressed, limiting their practical application and preventing them from fully exploiting the benefits of no post-processing, such as non-maximum suppression (NMS). In this paper, we first analyze the influence of NMS in modern real-time object detectors on inference speed, and establish an end-to-end speed benchmark. To avoid the inference delay caused by NMS, we propose a Real-Time DEtection TRansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS.
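For context, non-maximum suppression (NMS) is the post-processing step that conventional real-time detectors such as YOLO run to remove overlapping duplicate boxes, and that end-to-end DETR-style detectors avoid. The following is a minimal, generic greedy NMS sketch in plain Python. It assumes boxes in `(x0, y0, x1, y1)` pixel format and an illustrative `iou_threshold` of 0.5. It is not part of RT-DETR or of the Transformers API.

```python
def box_iou(a, b):
    """Intersection over union of two (x0, y0, x1, y1) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold=0.5):
    """Return indices of the boxes kept by greedy NMS, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap too much with any box already kept.
        if all(box_iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


# Two near-duplicate detections of the same object plus one separate object.
boxes = [(10, 10, 50, 50), (12, 12, 52, 52), (100, 100, 140, 140)]
scores = [0.9, 0.8, 0.7]
print(greedy_nms(boxes, scores))  # [0, 2]
```

DETR-style detectors, RT-DETR included, are trained with one-to-one matching between predictions and ground-truth objects, so each object is expected to be predicted by a single query and a pass like this is not needed at inference time.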
This model was contributed by rafaelpadilla and sangbumchoi. The original code can be found here.
Usage tips
First, the image is processed by a pre-trained convolutional neural network backbone, specifically a ResNet-D variant as used in the original code, which extracts features from the last three stages of the architecture. Next, a hybrid encoder converts these multi-scale features into a sequence of image features. A decoder equipped with auxiliary prediction heads then refines the object queries. Because this pipeline produces bounding boxes directly, no extra post-processing step such as non-maximum suppression (NMS) is needed to obtain the class logits and box coordinates.
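As a rough illustration of how multi-scale features become one sequence, the sketch below computes the feature map sizes for three backbone stages and flattens toy feature maps into a single token list. It assumes a 640×640 input, output strides of 8, 16 and 32 for the last three stages, and that every level has already been projected to the same channel dimension. These values are illustrative assumptions; check the image processor and backbone configuration of the checkpoint you actually use.

```python
def feature_map_sizes(image_height, image_width, strides=(8, 16, 32)):
    """(height, width) of each feature map, assuming exact downsampling by each stride."""
    return [(image_height // s, image_width // s) for s in strides]


def flattened_sequence_length(image_height, image_width, strides=(8, 16, 32)):
    """Number of tokens after flattening every H x W grid and concatenating the levels."""
    return sum(h * w for h, w in feature_map_sizes(image_height, image_width, strides))


def flatten_and_concat(feature_maps):
    """Flatten [H][W][C] nested-list feature maps row by row into one list of C-dim tokens."""
    tokens = []
    for fmap in feature_maps:
        for row in fmap:
            tokens.extend(row)
    return tokens


print(feature_map_sizes(640, 640))          # [(80, 80), (40, 40), (20, 20)]
print(flattened_sequence_length(640, 640))  # 8400
```

Under these assumptions the decoder sees 6400 + 1600 + 400 = 8400 image tokens, most of them from the highest-resolution level.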
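```python
import requests
import torch
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():  # inference only, no gradient tracking
    outputs = model(**inputs)

# target_sizes expects (height, width); PIL's image.size is (width, height), hence [::-1]
results = image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3
)

# Print each detection as "label: score [x0, y0, x1, y1]"
for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label_id.item()]}: {score.item():.2f} {box}")
```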
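`post_process_object_detection` turns the model's raw class logits and predicted boxes into scored boxes in pixel coordinates. The plain-Python sketch below illustrates that idea for a single image. It assumes boxes are predicted as normalized `(center_x, center_y, width, height)` and converted to absolute `(x0, y0, x1, y1)`, and that each query's score is the sigmoid of its best class logit, kept only above a threshold. This is an illustration under those assumptions, not the library's implementation, which may for example also keep a fixed number of top-scoring candidates.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def center_to_corners(box, image_height, image_width):
    """Convert a normalized (cx, cy, w, h) box to absolute (x0, y0, x1, y1) pixels."""
    cx, cy, w, h = box
    return (
        (cx - w / 2) * image_width,
        (cy - h / 2) * image_height,
        (cx + w / 2) * image_width,
        (cy + h / 2) * image_height,
    )


def postprocess(logits, boxes, image_height, image_width, threshold=0.3):
    """For each query, keep its best class if the sigmoid score exceeds `threshold`."""
    detections = []
    for query_logits, box in zip(logits, boxes):
        scores = [sigmoid(logit) for logit in query_logits]
        label = max(range(len(scores)), key=scores.__getitem__)
        if scores[label] > threshold:
            detections.append({
                "score": scores[label],
                "label": label,
                "box": center_to_corners(box, image_height, image_width),
            })
    return detections


# Two queries over two classes; only the first is confident enough to keep.
logits = [[2.0, -1.0], [-3.0, -2.0]]
boxes = [(0.5, 0.5, 0.2, 0.4), (0.1, 0.1, 0.1, 0.1)]
for det in postprocess(logits, boxes, image_height=480, image_width=640):
    print(det["label"], round(det["score"], 3), [round(v, 1) for v in det["box"]])
```

Using a per-class sigmoid rather than a softmax lets each query score every class independently, which is why a plain threshold is enough to filter weak queries here.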
RTDetrConfig
autodoc RTDetrConfig
RTDetrResNetConfig
autodoc RTDetrResNetConfig
RTDetrImageProcessor
autodoc RTDetrImageProcessor
- preprocess
- post_process_object_detection
RTDetrModel
autodoc RTDetrModel
- forward
RTDetrForObjectDetection
autodoc RTDetrForObjectDetection
- forward
RTDetrResNetBackbone
autodoc RTDetrResNetBackbone
- forward