Add OmDet-Turbo (#31843)

* Add template with add-new-model-like

* Add rough OmDetTurboEncoder and OmDetTurboDecoder

* Add working OmDetTurbo convert to hf

* Change OmDetTurbo encoder to RT-DETR encoder

* Add swin timm backbone as default, add always partition fix for swin timm

* Add labels and tasks caching

* Fix make fix-copies

* Format omdet_turbo

* fix Tokenizer tests

* Fix style and quality

* Reformat omdet_turbo

* Fix quality, style, copies

* Standardize processor kwargs

* Fix style

* Add output_hidden_states and output_attentions

* Add personalized multi-head attention, improve docstrings

* Add integrated test and fix copy, style, quality

* Fix unprotected import

* Cleanup comments and fix unprotected imports

* Add fix for different prompts in batch (key_padding_mask)

* Add key_padding_mask to custom multi-head attention module

* Replace attention_mask by key_padding_mask

* Remove OmDetTurboModel and refactor

* Refactor processing of classes and abstract use of timm backbone

* Add testing, fix output attentions and hidden states, add cache for anchors generation

* Fix copies, style, quality

* Add documentation, convert key_padding_mask to attention_mask

* revert changes to backbone_utils

* Fix docstrings rst

* Fix unused argument in config

* Fix image link documentation

* Reorder config and cleanup

* Add tokenizer_init_kwargs in merge_kwargs of the processor

* Change AutoTokenizer to CLIPTokenizer in convert

* Fix init_weights

* Add ProcessorMixin tests, Fix convert while waiting on uniform kwargs

* change processor kwargs and make task input optional

* Fix omdet docs

* Remove unnecessary tests for processor kwargs

* Replace nested BatchEncoding output of the processor by a flattened BatchFeature

* Make modifications from Pavel review

* Add changes from Amy's review

* Remove unused param

* Remove normalize_before param, Modify processor call docstring

* Remove redundant decoder class, add gradient checkpointing for decoder

* Remove commented out code

* Fix inference in fp16 and add fp16 integrated test

* update omdet md doc

* Add OmdetTurboModel

* fix caching and nit

* add OmDetTurboModel to tests

* nit change repeated key test

* Improve inference speed in eager mode

* fix copies

* Fix nit

* remove OmdetTurboModel

* [run-slow] omdet_turbo

* [run-slow] omdet_turbo

* skip dataparallel test

* [run-slow] omdet_turbo

* update weights to new path

* remove unnecessary config in class

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-91-248.ec2.internal>
Yoni Gozlan 2024-09-25 13:26:28 -04:00 committed by GitHub
parent ade9e0fe41
commit 94f18cf23c
18 changed files with 4354 additions and 1 deletion


@@ -862,6 +862,8 @@
title: MGP-STR
- local: model_doc/nougat
title: Nougat
- local: model_doc/omdet-turbo
title: OmDet-Turbo
- local: model_doc/oneformer
title: OneFormer
- local: model_doc/owlvit


@@ -237,6 +237,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Nyströmformer](model_doc/nystromformer) | ✅ | ❌ | ❌ |
| [OLMo](model_doc/olmo) | ✅ | ❌ | ❌ |
| [OLMoE](model_doc/olmoe) | ✅ | ❌ | ❌ |
| [OmDet-Turbo](model_doc/omdet-turbo) | ✅ | ❌ | ❌ |
| [OneFormer](model_doc/oneformer) | ✅ | ❌ | ❌ |
| [OpenAI GPT](model_doc/openai-gpt) | ✅ | ✅ | ❌ |
| [OpenAI GPT-2](model_doc/gpt2) | ✅ | ✅ | ✅ |


@@ -0,0 +1,164 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# OmDet-Turbo
## Overview
The OmDet-Turbo model was proposed in [Real-time Transformer-based Open-Vocabulary Detection with Efficient Fusion Head](https://arxiv.org/abs/2403.06892) by Tiancheng Zhao, Peng Liu, Xuan He, Lu Zhang, Kyusong Lee. OmDet-Turbo incorporates components from RT-DETR and introduces a swift multimodal fusion module to achieve real-time open-vocabulary object detection capabilities while maintaining high accuracy. The base model achieves performance of up to 100.2 FPS and 53.4 AP on COCO zero-shot.
The abstract from the paper is the following:
*End-to-end transformer-based detectors (DETRs) have shown exceptional performance in both closed-set and open-vocabulary object detection (OVD) tasks through the integration of language modalities. However, their demanding computational requirements have hindered their practical application in real-time object detection (OD) scenarios. In this paper, we scrutinize the limitations of two leading models in the OVDEval benchmark, OmDet and Grounding-DINO, and introduce OmDet-Turbo. This novel transformer-based real-time OVD model features an innovative Efficient Fusion Head (EFH) module designed to alleviate the bottlenecks observed in OmDet and Grounding-DINO. Notably, OmDet-Turbo-Base achieves a 100.2 frames per second (FPS) with TensorRT and language cache techniques applied. Notably, in zero-shot scenarios on COCO and LVIS datasets, OmDet-Turbo achieves performance levels nearly on par with current state-of-the-art supervised models. Furthermore, it establishes new state-of-the-art benchmarks on ODinW and OVDEval, boasting an AP of 30.1 and an NMS-AP of 26.86, respectively. The practicality of OmDet-Turbo in industrial applications is underscored by its exceptional performance on benchmark datasets and superior inference speed, positioning it as a compelling choice for real-time object detection tasks.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/omdet_turbo_architecture.jpeg" alt="drawing" width="600"/>
<small> OmDet-Turbo architecture overview. Taken from the <a href="https://arxiv.org/abs/2403.06892">original paper</a>. </small>
This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/om-ai-lab/OmDet).
## Usage tips
One unique property of OmDet-Turbo compared to other zero-shot object detection models, such as [Grounding DINO](grounding-dino), is the decoupled classes and prompt embedding structure that allows caching of text embeddings. This means that the model needs both classes and task as inputs, where classes is a list of objects we want to detect and task is the grounded text used to guide open-vocabulary detection. This approach limits the scope of the open-vocabulary detection and makes the decoding process faster.
[`OmDetTurboProcessor`] is used to prepare the classes, task and image triplet. The task input is optional; when not provided, it defaults to `"Detect [class1], [class2], [class3], ..."`. To process the results from the model, one can use `post_process_grounded_object_detection` from [`OmDetTurboProcessor`]. Notably, this function takes the input classes as an argument: because classes and task embeddings are decoupled, no decoding of the predicted class embeddings is needed in the post-processing step, and the predicted classes can be matched directly to the input ones, unlike in other zero-shot object detection models.
## Usage example
### Single image inference
Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image:
```python
import requests
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection
processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-tiny")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-tiny")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
classes = ["cat", "remote"]
inputs = processor(image, text=classes, return_tensors="pt")
outputs = model(**inputs)
# convert outputs (bounding boxes and class logits)
results = processor.post_process_grounded_object_detection(
    outputs,
    classes=classes,
    target_sizes=[image.size[::-1]],
    score_threshold=0.3,
    nms_threshold=0.3,
)[0]
for score, class_name, box in zip(
    results["scores"], results["classes"], results["boxes"]
):
    box = [round(i, 1) for i in box.tolist()]
    print(
        f"Detected {class_name} with confidence "
        f"{round(score.item(), 2)} at location {box}"
    )
```
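The `task` prompt was omitted above, so the processor builds the default prompt (here `"Detect cat, remote."`) internally. As a minimal sketch of that equivalence (reusing `processor`, `image` and `classes` from the snippet above; the `tasks_input_ids` key name follows the processor implementation added in this PR):

```python
inputs_default = processor(image, text=classes, return_tensors="pt")
inputs_explicit = processor(image, text=classes, task="Detect cat, remote.", return_tensors="pt")

# Both calls should yield the same task encoding.
assert (inputs_default["tasks_input_ids"] == inputs_explicit["tasks_input_ids"]).all()
```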
### Multi image inference
OmDet-Turbo can perform batched multi-image inference, with support for different text prompts and classes in the same batch:
```python
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
>>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
>>> classes1 = ["cat", "remote"]
>>> task1 = "Detect {}.".format(", ".join(classes1))
>>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
>>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
>>> classes2 = ["boat"]
>>> task2 = "Detect everything that looks like a boat."
>>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
>>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
>>> classes3 = ["statue", "trees"]
>>> task3 = "Focus on the foreground, detect statue and trees."
>>> inputs = processor(
...     images=[image1, image2, image3],
...     text=[classes1, classes2, classes3],
...     task=[task1, task2, task3],
...     return_tensors="pt",
... )
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
...     outputs,
...     classes=[classes1, classes2, classes3],
...     target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
...     score_threshold=0.2,
...     nms_threshold=0.3,
... )
>>> for i, result in enumerate(results):
...     for score, class_name, box in zip(
...         result["scores"], result["classes"], result["boxes"]
...     ):
...         box = [round(i, 1) for i in box.tolist()]
...         print(
...             f"Detected {class_name} with confidence "
...             f"{round(score.item(), 2)} at location {box} in image {i}"
...         )
Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
Detected cat with confidence 0.72 at location [11.6, 54.2, 314.8, 474.0] in image 0
Detected remote with confidence 0.56 at location [333.4, 75.8, 370.7, 187.0] in image 0
Detected cat with confidence 0.55 at location [345.2, 24.0, 639.8, 371.7] in image 0
Detected boat with confidence 0.32 at location [146.9, 219.8, 209.6, 250.7] in image 1
Detected boat with confidence 0.3 at location [319.1, 223.2, 403.2, 238.4] in image 1
Detected boat with confidence 0.27 at location [37.7, 220.3, 84.0, 235.9] in image 1
Detected boat with confidence 0.22 at location [407.9, 207.0, 441.7, 220.2] in image 1
Detected statue with confidence 0.73 at location [544.7, 210.2, 651.9, 502.8] in image 2
Detected trees with confidence 0.25 at location [3.9, 584.3, 391.4, 785.6] in image 2
Detected trees with confidence 0.25 at location [1.4, 621.2, 118.2, 787.8] in image 2
Detected statue with confidence 0.2 at location [428.1, 205.5, 767.3, 759.5] in image 2
```
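This PR also adds an fp16 integration test, so half-precision inference is supported. A rough sketch of running the single-image example in `torch.float16`, assuming a CUDA device is available (post-process the outputs as shown above):

```python
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, OmDetTurboForObjectDetection

device = "cuda"  # assumes a CUDA-capable GPU

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = OmDetTurboForObjectDetection.from_pretrained(
    "omlab/omdet-turbo-swin-tiny-hf", torch_dtype=torch.float16
).to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
classes = ["cat", "remote"]

# Cast the floating-point inputs (pixel values) to fp16 and move everything to the GPU.
inputs = processor(image, text=classes, return_tensors="pt").to(device, torch.float16)

with torch.no_grad():
    outputs = model(**inputs)
# `outputs` can now be post-processed with `post_process_grounded_object_detection` as above.
```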
## OmDetTurboConfig
[[autodoc]] OmDetTurboConfig
## OmDetTurboProcessor
[[autodoc]] OmDetTurboProcessor
- post_process_grounded_object_detection
## OmDetTurboForObjectDetection
[[autodoc]] OmDetTurboForObjectDetection
- forward


@@ -609,6 +609,10 @@ _import_structure = {
"models.nystromformer": ["NystromformerConfig"],
"models.olmo": ["OlmoConfig"],
"models.olmoe": ["OlmoeConfig"],
"models.omdet_turbo": [
"OmDetTurboConfig",
"OmDetTurboProcessor",
],
"models.oneformer": [
"OneFormerConfig",
"OneFormerProcessor",
@@ -2861,6 +2865,12 @@ else:
"OlmoePreTrainedModel",
]
)
_import_structure["models.omdet_turbo"].extend(
[
"OmDetTurboForObjectDetection",
"OmDetTurboPreTrainedModel",
]
)
_import_structure["models.oneformer"].extend(
[
"OneFormerForUniversalSegmentation",
@@ -5407,6 +5417,10 @@ if TYPE_CHECKING:
)
from .models.olmo import OlmoConfig
from .models.olmoe import OlmoeConfig
from .models.omdet_turbo import (
OmDetTurboConfig,
OmDetTurboProcessor,
)
from .models.oneformer import (
OneFormerConfig,
OneFormerProcessor,
@@ -7383,6 +7397,10 @@ if TYPE_CHECKING:
OlmoeModel,
OlmoePreTrainedModel,
)
from .models.omdet_turbo import (
OmDetTurboForObjectDetection,
OmDetTurboPreTrainedModel,
)
from .models.oneformer import (
OneFormerForUniversalSegmentation,
OneFormerModel,


@@ -173,6 +173,7 @@ from . import (
nystromformer,
olmo,
olmoe,
omdet_turbo,
oneformer,
openai,
opt,


@@ -60,6 +60,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
("clap", "ClapConfig"),
("clip", "CLIPConfig"),
("clip_text_model", "CLIPTextConfig"),
("clip_vision_model", "CLIPVisionConfig"),
("clipseg", "CLIPSegConfig"),
("clvp", "ClvpConfig"),
@@ -191,6 +192,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("nystromformer", "NystromformerConfig"),
("olmo", "OlmoConfig"),
("olmoe", "OlmoeConfig"),
("omdet-turbo", "OmDetTurboConfig"),
("oneformer", "OneFormerConfig"),
("open-llama", "OpenLlamaConfig"),
("openai-gpt", "OpenAIGPTConfig"),
@@ -348,6 +350,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "CLAP"),
("clip", "CLIP"),
("clip_text_model", "CLIPTextModel"),
("clip_vision_model", "CLIPVisionModel"),
("clipseg", "CLIPSeg"),
("clvp", "CLVP"),
@@ -497,6 +500,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("nystromformer", "Nyströmformer"),
("olmo", "OLMo"),
("olmoe", "OLMoE"),
("omdet-turbo", "OmDet-Turbo"),
("oneformer", "OneFormer"),
("open-llama", "OpenLlama"),
("openai-gpt", "OpenAI GPT"),
@@ -665,6 +669,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("xclip", "x_clip"),
("clip_vision_model", "clip"),
("qwen2_audio_encoder", "qwen2_audio"),
("clip_text_model", "clip"),
("siglip_vision_model", "siglip"),
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),


@@ -60,6 +60,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "ClapModel"),
("clip", "CLIPModel"),
("clip_text_model", "CLIPTextModel"),
("clip_vision_model", "CLIPVisionModel"),
("clipseg", "CLIPSegModel"),
("clvp", "ClvpModelForConditionalGeneration"),
@@ -181,6 +182,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("nystromformer", "NystromformerModel"),
("olmo", "OlmoModel"),
("olmoe", "OlmoeModel"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("oneformer", "OneFormerModel"),
("open-llama", "OpenLlamaModel"),
("openai-gpt", "OpenAIGPTModel"),
@@ -812,6 +814,7 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
[
# Model for Zero Shot Object Detection mapping
("grounding-dino", "GroundingDinoForObjectDetection"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("owlv2", "Owlv2ForObjectDetection"),
("owlvit", "OwlViTForObjectDetection"),
]
@@ -1326,6 +1329,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
("albert", "AlbertModel"),
("bert", "BertModel"),
("big_bird", "BigBirdModel"),
("clip_text_model", "CLIPTextModel"),
("data2vec-text", "Data2VecTextModel"),
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),


@@ -344,6 +344,10 @@ else:
),
("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
(
"omdet-turbo",
("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None),
),
("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
(
"openai-gpt",


@@ -0,0 +1,56 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_omdet_turbo": ["OmDetTurboConfig"],
"processing_omdet_turbo": ["OmDetTurboProcessor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_omdet_turbo"] = [
"OmDetTurboForObjectDetection",
"OmDetTurboPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_omdet_turbo import (
OmDetTurboConfig,
)
from .processing_omdet_turbo import OmDetTurboProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_omdet_turbo import (
OmDetTurboForObjectDetection,
OmDetTurboPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
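The lazy-module registration above is what exposes the new symbols at the top level of the package; a minimal sketch of what it enables (the modeling class additionally requires torch, per the optional-dependency guard above):

```python
# The config and processor import without torch; the model class needs torch installed.
from transformers import OmDetTurboConfig, OmDetTurboProcessor
from transformers import OmDetTurboForObjectDetection

config = OmDetTurboConfig()
print(config.model_type)  # "omdet-turbo"
```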


@@ -0,0 +1,290 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OmDet-Turbo model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class OmDetTurboConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`].
It is used to instantiate an OmDet-Turbo model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo
[omlab/omdet-turbo-swin-tiny-hf](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`PretrainedConfig`, *optional*):
The configuration of the text backbone.
backbone_config (`PretrainedConfig`, *optional*):
The configuration of the vision backbone.
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether to use the timm library for the vision backbone.
backbone (`str`, *optional*, defaults to `"swin_tiny_patch4_window7_224"`):
The name of the pretrained vision backbone to use. If `use_pretrained_backbone=False`, a randomly initialized
backbone with the same architecture as `backbone` is used.
backbone_kwargs (`dict`, *optional*):
Additional kwargs for the vision backbone.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use a pretrained vision backbone.
apply_layernorm_after_vision_backbone (`bool`, *optional*, defaults to `True`):
Whether to apply layer normalization on the feature maps of the vision backbone output.
image_size (`int`, *optional*, defaults to 640):
The size (resolution) of each image.
disable_custom_kernels (`bool`, *optional*, defaults to `False`):
Whether to disable custom kernels.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon value for layer normalization.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon value for batch normalization.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
text_projection_in_dim (`int`, *optional*, defaults to 512):
The input dimension for the text projection.
text_projection_out_dim (`int`, *optional*, defaults to 512):
The output dimension for the text projection.
task_encoder_hidden_dim (`int`, *optional*, defaults to 1024):
The feedforward dimension for the task encoder.
class_embed_dim (`int`, *optional*, defaults to 512):
The dimension of the classes embeddings.
class_distance_type (`str`, *optional*, defaults to `"cosine"`):
The type of distance used to compare predicted class embeddings to projected class embeddings.
Can be `"cosine"` or `"dot"`.
num_queries (`int`, *optional*, defaults to 900):
The number of queries.
csp_activation (`str`, *optional*, defaults to `"silu"`):
The activation function of the Cross Stage Partial (CSP) networks of the encoder.
conv_norm_activation (`str`, *optional*, defaults to `"gelu"`):
The activation function of the ConvNormLayer layers of the encoder.
encoder_feedforward_activation (`str`, *optional*, defaults to `"relu"`):
The activation function for the feedforward network of the encoder.
encoder_feedforward_dropout (`float`, *optional*, defaults to 0.0):
The dropout rate following the activation of the encoder feedforward network.
encoder_dropout (`float`, *optional*, defaults to 0.0):
The dropout rate of the encoder multi-head attention module.
hidden_expansion (`int`, *optional*, defaults to 1):
The hidden expansion of the CSP networks in the encoder.
vision_features_channels (`tuple(int)`, *optional*, defaults to `[256, 256, 256]`):
The projected vision features channels used as inputs for the decoder.
encoder_hidden_dim (`int`, *optional*, defaults to 256):
The hidden dimension of the encoder.
encoder_in_channels (`List(int)`, *optional*, defaults to `[192, 384, 768]`):
The input channels for the encoder.
encoder_projection_indices (`List(int)`, *optional*, defaults to `[2]`):
The indices of the input features projected by each layer.
encoder_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads for the encoder.
encoder_dim_feedforward (`int`, *optional*, defaults to 2048):
The feedforward dimension for the encoder.
encoder_layers (`int`, *optional*, defaults to 1):
The number of layers in the encoder.
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
The positional encoding temperature in the encoder.
num_feature_levels (`int`, *optional*, defaults to 3):
The number of feature levels for the multi-scale deformable attention module of the decoder.
decoder_hidden_dim (`int`, *optional*, defaults to 256):
The hidden dimension of the decoder.
decoder_num_heads (`int`, *optional*, defaults to 8):
The number of heads for the decoder.
decoder_num_layers (`int`, *optional*, defaults to 6):
The number of layers for the decoder.
decoder_activation (`str`, *optional*, defaults to `"relu"`):
The activation function for the decoder.
decoder_dim_feedforward (`int`, *optional*, defaults to 2048):
The feedforward dimension for the decoder.
decoder_num_points (`int`, *optional*, defaults to 4):
The number of points sampled in the decoder multi-scale deformable attention module.
decoder_dropout (`float`, *optional*, defaults to 0.0):
The dropout rate for the decoder.
eval_size (`Tuple[int, int]`, *optional*):
Height and width used to compute the effective height and width of the position embeddings after taking
into account the stride (see RTDetr).
learn_initial_query (`bool`, *optional*, defaults to `False`):
Whether to learn the initial query.
cache_size (`int`, *optional*, defaults to 100):
The cache size for the classes and prompts caches.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Whether the model is used as an encoder-decoder model or not.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from the architecture. The values in kwargs will be saved as part of the configuration
and can be used to control the model outputs.
Examples:
```python
>>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection
>>> # Initializing a OmDet-Turbo omlab/omdet-turbo-tiny style configuration
>>> configuration = OmDetTurboConfig()
>>> # Initializing a model (with random weights) from the omlab/omdet-turbo-tiny style configuration
>>> model = OmDetTurboForObjectDetection(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "omdet-turbo"
attribute_map = {
"encoder_hidden_dim": "d_model",
"num_attention_heads": "encoder_attention_heads",
}
def __init__(
self,
text_config=None,
backbone_config=None,
use_timm_backbone=True,
backbone="swin_tiny_patch4_window7_224",
backbone_kwargs=None,
use_pretrained_backbone=False,
apply_layernorm_after_vision_backbone=True,
image_size=640,
disable_custom_kernels=False,
layer_norm_eps=1e-5,
batch_norm_eps=1e-5,
init_std=0.02,
text_projection_in_dim=512,
text_projection_out_dim=512,
task_encoder_hidden_dim=1024,
class_embed_dim=512,
class_distance_type="cosine",
num_queries=900,
csp_activation="silu",
conv_norm_activation="gelu",
encoder_feedforward_activation="relu",
encoder_feedforward_dropout=0.0,
encoder_dropout=0.0,
hidden_expansion=1,
vision_features_channels=[256, 256, 256],
encoder_hidden_dim=256,
encoder_in_channels=[192, 384, 768],
encoder_projection_indices=[2],
encoder_attention_heads=8,
encoder_dim_feedforward=2048,
encoder_layers=1,
positional_encoding_temperature=10000,
num_feature_levels=3,
decoder_hidden_dim=256,
decoder_num_heads=8,
decoder_num_layers=6,
decoder_activation="relu",
decoder_dim_feedforward=2048,
decoder_num_points=4,
decoder_dropout=0.0,
eval_size=None,
learn_initial_query=False,
cache_size=100,
is_encoder_decoder=True,
**kwargs,
):
if use_timm_backbone:
if backbone_config is None:
backbone_kwargs = {
"out_indices": [1, 2, 3],
"img_size": image_size,
"always_partition": True,
}
elif backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `swin` vision config.")
backbone_config = CONFIG_MAPPING["swin"](
window_size=7,
image_size=image_size,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
out_indices=[2, 3, 4],
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
if text_config is None:
logger.info(
"`text_config` is `None`. Initializing the config with the default `clip_text_model` text config."
)
text_config = CONFIG_MAPPING["clip_text_model"]()
elif isinstance(text_config, dict):
text_model_type = text_config.get("model_type")
text_config = CONFIG_MAPPING[text_model_type](**text_config)
if class_distance_type not in ["cosine", "dot"]:
raise ValueError(
f"Invalid `class_distance_type`. It should be either `cosine` or `dot`, but got {class_distance_type}."
)
self.text_config = text_config
self.backbone_config = backbone_config
self.use_timm_backbone = use_timm_backbone
self.backbone = backbone
self.backbone_kwargs = backbone_kwargs
self.use_pretrained_backbone = use_pretrained_backbone
self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone
self.image_size = image_size
self.disable_custom_kernels = disable_custom_kernels
self.layer_norm_eps = layer_norm_eps
self.batch_norm_eps = batch_norm_eps
self.init_std = init_std
self.text_projection_in_dim = text_projection_in_dim
self.text_projection_out_dim = text_projection_out_dim
self.task_encoder_hidden_dim = task_encoder_hidden_dim
self.class_embed_dim = class_embed_dim
self.class_distance_type = class_distance_type
self.num_queries = num_queries
self.csp_activation = csp_activation
self.conv_norm_activation = conv_norm_activation
self.encoder_feedforward_activation = encoder_feedforward_activation
self.encoder_feedforward_dropout = encoder_feedforward_dropout
self.encoder_dropout = encoder_dropout
self.hidden_expansion = hidden_expansion
self.vision_features_channels = vision_features_channels
self.encoder_hidden_dim = encoder_hidden_dim
self.encoder_in_channels = encoder_in_channels
self.encoder_projection_indices = encoder_projection_indices
self.encoder_attention_heads = encoder_attention_heads
self.encoder_dim_feedforward = encoder_dim_feedforward
self.encoder_layers = encoder_layers
self.positional_encoding_temperature = positional_encoding_temperature
self.num_feature_levels = num_feature_levels
self.decoder_hidden_dim = decoder_hidden_dim
self.decoder_num_heads = decoder_num_heads
self.decoder_num_layers = decoder_num_layers
self.decoder_activation = decoder_activation
self.decoder_dim_feedforward = decoder_dim_feedforward
self.decoder_num_points = decoder_num_points
self.decoder_dropout = decoder_dropout
self.eval_size = eval_size
self.learn_initial_query = learn_initial_query
self.cache_size = cache_size
self.is_encoder_decoder = is_encoder_decoder
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
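To make the backbone branching in `__init__` concrete, here is a minimal sketch of the two default paths (timm Swin-Tiny vs. a `transformers` Swin config). Note that `backbone=None` is passed on the non-timm path, mirroring what the conversion script below does:

```python
from transformers import OmDetTurboConfig

# Default path: timm "swin_tiny_patch4_window7_224" backbone with the kwargs set above.
timm_config = OmDetTurboConfig()
print(timm_config.backbone)
# swin_tiny_patch4_window7_224
print(timm_config.backbone_kwargs)
# {'out_indices': [1, 2, 3], 'img_size': 640, 'always_partition': True}

# Non-timm path: a default `swin` backbone config is created instead.
hf_config = OmDetTurboConfig(use_timm_backbone=False, backbone=None)
print(hf_config.backbone_config.model_type)
# swin
```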


@@ -0,0 +1,349 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OmDet-Turbo checkpoints from the original repository.
URL: https://github.com/om-ai-lab/OmDet"""
import argparse
import requests
import torch
from PIL import Image
from transformers import (
CLIPTokenizer,
DetrImageProcessor,
OmDetTurboConfig,
OmDetTurboForObjectDetection,
OmDetTurboProcessor,
)
IMAGE_MEAN = [123.675, 116.28, 103.53]
IMAGE_STD = [58.395, 57.12, 57.375]
def get_omdet_turbo_config(model_name, use_timm_backbone):
if "tiny" in model_name:
window_size = 7
embed_dim = 96
depths = (2, 2, 6, 2)
num_heads = (3, 6, 12, 24)
image_size = 640
else:
raise ValueError("Model not supported, only supports tiny variant.")
config = OmDetTurboConfig(
backbone_window_size=window_size,
backbone_image_size=image_size,
backbone_embed_dim=embed_dim,
backbone_depths=depths,
backbone_num_heads=num_heads,
backbone_out_indices=(1, 2, 3),
text_config={"model_type": "clip_text_model"},
use_timm_backbone=use_timm_backbone,
backbone="swin_tiny_patch4_window7_224" if use_timm_backbone else None,
apply_layernorm_after_vision_backbone=True if use_timm_backbone else False,
use_pretrained_backbone=False,
)
return config
def create_rename_keys_vision(state_dict, config):
rename_keys = []
# fmt: off
########################################## VISION BACKBONE - START
for layer_name in state_dict.keys():
if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"):
if config.use_timm_backbone:
layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone")
layer_name_replace = layer_name_replace.replace(".layers.", ".layers_")
if "downsample" in layer_name:
# get layer number
layer_num = int(layer_name.split(".")[2])
layer_name_replace = layer_name_replace.replace(f"{layer_num}.downsample", f"{layer_num+1}.downsample")
else:
layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone")
layer_name_replace = layer_name_replace.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
layer_name_replace = layer_name_replace.replace("patch_embed.norm", "embeddings.norm")
if layer_name.startswith("backbone.layers"):
layer_name_replace = layer_name_replace.replace("norm1", "layernorm_before")
layer_name_replace = layer_name_replace.replace("norm2", "layernorm_after")
layer_name_replace = layer_name_replace.replace("attn.proj", "attention.output.dense")
layer_name_replace = layer_name_replace.replace("mlp.fc1", "intermediate.dense")
layer_name_replace = layer_name_replace.replace("mlp.fc2", "output.dense")
layer_name_replace = layer_name_replace.replace(".layers.", ".encoder.layers.")
layer_name_replace = layer_name_replace.replace(".attn.", ".attention.self.")
elif layer_name.startswith("backbone.norm"):
layer_num = int(layer_name.split("norm")[1].split(".")[0])
if config.use_timm_backbone:
layer_name_replace = layer_name.replace("backbone", "vision_backbone")
layer_name_replace = layer_name_replace.replace(f"norm{layer_num}", f"layer_norms.{layer_num-1}")
else:
layer_name_replace = layer_name.replace(f"backbone.norm{layer_num}", f"vision_backbone.vision_backbone.hidden_states_norms.stage{layer_num+1}")
else:
continue
rename_keys.append((layer_name, layer_name_replace))
########################################## VISION BACKBONE - END
########################################## ENCODER - START
for layer_name, params in state_dict.items():
if "neck" in layer_name:
layer_name_replace = layer_name.replace("neck", "encoder")
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
if "fpn_blocks" in layer_name or "pan_blocks" in layer_name or "lateral_convs" in layer_name or "downsample_convs" in layer_name:
layer_name_replace = layer_name_replace.replace(".m.", ".bottlenecks.")
layer_name_replace = layer_name_replace.replace(".cv", ".conv")
layer_name_replace = layer_name_replace.replace(".bn", ".norm")
if "encoder_layer" in layer_name:
layer_name_replace = layer_name_replace.replace("encoder_layer", "encoder.0.layers.0")
layer_name_replace = layer_name_replace.replace(".linear", ".fc")
layer_name_replace = layer_name_replace.replace("norm1", "self_attn_layer_norm")
layer_name_replace = layer_name_replace.replace("norm2", "final_layer_norm")
rename_keys.append((layer_name, layer_name_replace))
########################################## ENCODER - END
########################################## DECODER - START
for layer_name, params in state_dict.items():
if layer_name.startswith("decoder"):
layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers")
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
layer_name_replace = layer_name_replace.replace("query_pos_head", "query_position_head")
layer_name_replace = layer_name_replace.replace("enc_bbox_head", "encoder_bbox_head")
layer_name_replace = layer_name_replace.replace("enc_output", "encoder_vision_features")
layer_name_replace = layer_name_replace.replace("dec_score_head", "decoder_class_head")
layer_name_replace = layer_name_replace.replace("dec_bbox_head", "decoder_bbox_head")
layer_name_replace = layer_name_replace.replace("enc_score_head", "encoder_class_head")
rename_keys.append((layer_name, layer_name_replace))
########################################## DECODER - END
# fmt: on
return rename_keys
def create_rename_keys_language(state_dict):
rename_keys = []
# fmt: off
for layer_name in state_dict.keys():
if layer_name.startswith("language_backbone") and not layer_name.startswith("language_backbone.text_projection"):
layer_name_replace = layer_name.replace("language_backbone", "language_backbone.model.text_model")
layer_name_replace = layer_name_replace.replace("transformer.resblocks", "encoder.layers")
layer_name_replace = layer_name_replace.replace("token_embedding", "embeddings.token_embedding")
layer_name_replace = layer_name_replace.replace("positional_embedding", "embeddings.position_embedding.weight")
layer_name_replace = layer_name_replace.replace(".attn", ".self_attn")
layer_name_replace = layer_name_replace.replace(".mlp.c_fc", ".mlp.fc1")
layer_name_replace = layer_name_replace.replace(".mlp.c_proj", ".mlp.fc2")
layer_name_replace = layer_name_replace.replace("ln_final", "final_layer_norm")
layer_name_replace = layer_name_replace.replace(".ln_", ".layer_norm")
rename_keys.append((layer_name, layer_name_replace))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v_vision(state_dict, config):
state_dict_keys = list(state_dict.keys())
for layer_name_vision in state_dict_keys:
if layer_name_vision.startswith("vision_backbone") and "qkv" in layer_name_vision:
layer_num = int(layer_name_vision.split(".")[4])
hidden_size = config.backbone_config.embed_dim * 2**layer_num
if "weight" in layer_name_vision:
in_proj_weight = state_dict.pop(layer_name_vision)
state_dict[layer_name_vision.replace("qkv.weight", "key.weight")] = in_proj_weight[:hidden_size, :]
state_dict[layer_name_vision.replace("qkv.weight", "query.weight")] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[layer_name_vision.replace("qkv.weight", "value.weight")] = in_proj_weight[-hidden_size:, :]
elif "bias" in layer_name_vision:
in_proj_bias = state_dict.pop(layer_name_vision)
state_dict[layer_name_vision.replace("qkv.bias", "key.bias")] = in_proj_bias[:hidden_size]
state_dict[layer_name_vision.replace("qkv.bias", "query.bias")] = in_proj_bias[
hidden_size : hidden_size * 2
]
state_dict[layer_name_vision.replace("qkv.bias", "value.bias")] = in_proj_bias[-hidden_size:]
def read_in_q_k_v_text(state_dict, config):
state_dict_keys = list(state_dict.keys())
hidden_size = config.text_config.projection_dim
for layer_name_text in state_dict_keys:
if layer_name_text.startswith("language_backbone") and "in_proj" in layer_name_text:
if "weight" in layer_name_text:
in_proj_weight = state_dict.pop(layer_name_text)
state_dict[layer_name_text.replace("in_proj_weight", "q_proj.weight")] = in_proj_weight[
:hidden_size, :
]
state_dict[layer_name_text.replace("in_proj_weight", "k_proj.weight")] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[layer_name_text.replace("in_proj_weight", "v_proj.weight")] = in_proj_weight[
-hidden_size:, :
]
elif "bias" in layer_name_text:
in_proj_bias = state_dict.pop(layer_name_text)
state_dict[layer_name_text.replace("in_proj_bias", "q_proj.bias")] = in_proj_bias[:hidden_size]
state_dict[layer_name_text.replace("in_proj_bias", "k_proj.bias")] = in_proj_bias[
hidden_size : hidden_size * 2
]
state_dict[layer_name_text.replace("in_proj_bias", "v_proj.bias")] = in_proj_bias[-hidden_size:]
def read_in_q_k_v_encoder(state_dict, config):
embed_dim = config.encoder_hidden_dim
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict["encoder.encoder.0.layers.0.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
state_dict["encoder.encoder.0.layers.0.self_attn.query.bias"] = in_proj_bias[:embed_dim]
state_dict["encoder.encoder.0.layers.0.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
state_dict["encoder.encoder.0.layers.0.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
state_dict["encoder.encoder.0.layers.0.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
state_dict["encoder.encoder.0.layers.0.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
def read_in_q_k_v_decoder(state_dict, config):
for layer_num in range(config.decoder_num_layers):
embed_dim = config.decoder_hidden_dim
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"decoder.layers.{layer_num}.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
state_dict[f"decoder.layers.{layer_num}.self_attn.query.bias"] = in_proj_bias[:embed_dim]
state_dict[f"decoder.layers.{layer_num}.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
state_dict[f"decoder.layers.{layer_num}.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
state_dict[f"decoder.layers.{layer_num}.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
state_dict[f"decoder.layers.{layer_num}.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
def run_test(model, processor):
# We will verify our results on an image of cute cats
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
classes = ["cat", "remote"]
task = "Detect {}.".format(", ".join(classes))
inputs = processor(image, text=classes, task=task, return_tensors="pt")
# Running forward
with torch.no_grad():
outputs = model(**inputs)
predicted_slice = outputs[1][0, :3, :3]
print(predicted_slice)
expected_slice = torch.tensor([[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]])
assert torch.allclose(predicted_slice, expected_slice, atol=1e-4)
print("Looks ok!")
@torch.no_grad()
def convert_omdet_turbo_checkpoint(args):
model_name = args.model_name
pytorch_dump_folder_path = args.pytorch_dump_folder_path
push_to_hub = args.push_to_hub
use_timm_backbone = args.use_timm_backbone
checkpoint_mapping = {
"omdet-turbo-tiny": [
"https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/OmDet-Turbo_tiny_SWIN_T.pth",
"https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/ViT-B-16.pt",
],
}
# Define default OmDetTurbo configuration
config = get_omdet_turbo_config(model_name, use_timm_backbone)
# Load original checkpoint
checkpoint_url = checkpoint_mapping[model_name]
original_state_dict_vision = torch.hub.load_state_dict_from_url(checkpoint_url[0], map_location="cpu")["model"]
original_state_dict_vision = {k.replace("module.", ""): v for k, v in original_state_dict_vision.items()}
# Rename keys
new_state_dict = original_state_dict_vision.copy()
rename_keys_vision = create_rename_keys_vision(new_state_dict, config)
rename_keys_language = create_rename_keys_language(new_state_dict)
for src, dest in rename_keys_vision:
rename_key(new_state_dict, src, dest)
for src, dest in rename_keys_language:
rename_key(new_state_dict, src, dest)
if not use_timm_backbone:
read_in_q_k_v_vision(new_state_dict, config)
read_in_q_k_v_text(new_state_dict, config)
read_in_q_k_v_encoder(new_state_dict, config)
read_in_q_k_v_decoder(new_state_dict, config)
# add "model" prefix to all keys
new_state_dict = {f"model.{k}": v for k, v in new_state_dict.items()}
# Load HF model
model = OmDetTurboForObjectDetection(config)
model.eval()
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)
image_processor = DetrImageProcessor(
size={"height": config.backbone_image_size, "width": config.backbone_image_size},
do_rescale=False,
image_mean=IMAGE_MEAN,
image_std=IMAGE_STD,
do_pad=False,
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer)
# end-to-end consistency test
run_test(model, processor)
if pytorch_dump_folder_path is not None:
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model.push_to_hub(f"omlab/{model_name}")
processor.push_to_hub(f"omlab/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="omdet-turbo-tiny",
type=str,
choices=["omdet-turbo-tiny"],
help="Name of the OmDetTurbo model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
parser.add_argument(
"--use_timm_backbone", action="store_true", help="Whether or not to use timm backbone for vision backbone."
)
args = parser.parse_args()
convert_omdet_turbo_checkpoint(args)
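The `read_in_q_k_v_*` helpers above split each fused `in_proj_weight`/`in_proj_bias` into three equal chunks along the first dimension. A toy sanity check of that slicing (hypothetical shapes, not tied to an actual checkpoint):

```python
import torch

embed_dim = 4
# Fused projection as stored in the original checkpoints: (3 * embed_dim, embed_dim).
in_proj_weight = torch.arange(3 * embed_dim * embed_dim, dtype=torch.float32).reshape(3 * embed_dim, embed_dim)

first_chunk = in_proj_weight[:embed_dim, :]
second_chunk = in_proj_weight[embed_dim : embed_dim * 2, :]
third_chunk = in_proj_weight[-embed_dim:, :]

# The three slices are disjoint, each of shape (embed_dim, embed_dim), and tile the fused matrix exactly.
assert torch.equal(torch.cat([first_chunk, second_chunk, third_chunk], dim=0), in_proj_weight)
```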

File diff suppressed because it is too large.


@@ -0,0 +1,362 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for OmDet-Turbo.
"""
import sys
from typing import List, Optional, Tuple, Union
from ...feature_extraction_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import (
TensorType,
is_torch_available,
is_torchvision_available,
)
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
class OmDetTurboTextKwargs(TextKwargs, total=False):
task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]
class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: OmDetTurboTextKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": "max_length",
"truncation": True,
"max_length": 77,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
"task": None,
},
"images_kwargs": {},
}
if is_torch_available():
import torch
if is_torchvision_available():
from torchvision.ops.boxes import batched_nms
def clip_boxes(box, box_size: Tuple[int, int]):
"""
Clip the boxes by limiting x coordinates to the range [0, width]
and y coordinates to the range [0, height].
Args:
box (Tensor): The box to be clipped.
box_size (height, width): The clipping box's size.
"""
assert torch.isfinite(box).all(), "Box tensor contains infinite or NaN!"
height, width = box_size
x1 = box[:, 0].clamp(min=0, max=width)
y1 = box[:, 1].clamp(min=0, max=height)
x2 = box[:, 2].clamp(min=0, max=width)
y2 = box[:, 3].clamp(min=0, max=height)
box = torch.stack((x1, y1, x2, y2), dim=-1)
return box
def compute_score(boxes):
"""
Compute logit scores per class for each box (proposal) and an array of class indices
corresponding to each proposal, flattened across the proposal_num.
The indices in `classes` will later be used to filter and match the predicted classes
with the input class names.
"""
num_classes = boxes.shape[2]
proposal_num = boxes.shape[1]
scores = torch.sigmoid(boxes)
classes = torch.arange(num_classes, device=boxes.device).unsqueeze(0).repeat(proposal_num, 1).flatten(0, 1)
return scores, classes
def _post_process_boxes_for_image(
boxes: TensorType,
scores: TensorType,
predicted_classes: TensorType,
classes: List[str],
image_size: Tuple[int, int],
num_classes: int,
score_threshold: float,
nms_threshold: float,
max_num_det: int = None,
) -> dict:
"""
Filter predicted results using given thresholds and NMS.
Args:
boxes (torch.Tensor): A Tensor of predicted class-specific or class-agnostic
boxes for the image. Shape : (num_queries, max_num_classes_in_batch * 4) if doing
class-specific regression, or (num_queries, 4) if doing class-agnostic
regression.
scores (torch.Tensor): A Tensor of predicted class scores for the image.
Shape : (num_queries, max_num_classes_in_batch + 1)
predicted_classes (torch.Tensor): A Tensor of predicted classes for the image.
Shape : (num_queries * (max_num_classes_in_batch + 1),)
classes (List[str]): The input classes names.
image_size (tuple): A tuple of (height, width) for the image.
num_classes (int): The number of classes given for this image.
score_threshold (float): Only return detections with a confidence score exceeding this
threshold.
nms_threshold (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
max_num_det (int, optional): The maximum number of detections to return. Default is None.
Returns:
dict: A dictionary with the following keys:
"boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format.
"scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection.
"classes" (List[str]): A list of strings, where each string is the predicted class for the
corresponding detection
"""
proposal_num = len(boxes) if max_num_det is None else max_num_det
scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False)
classes_per_image = predicted_classes[topk_indices]
box_pred_per_image = boxes.view(-1, 1, 4).repeat(1, num_classes, 1).view(-1, 4)
box_pred_per_image = box_pred_per_image[topk_indices]
# Score filtering
box_pred_per_image = center_to_corners_format(box_pred_per_image)
box_pred_per_image = box_pred_per_image * torch.tensor(image_size[::-1]).repeat(2).to(box_pred_per_image.device)
filter_mask = scores_per_image > score_threshold # R x K
score_keep = filter_mask.nonzero(as_tuple=False).view(-1)
box_pred_per_image = box_pred_per_image[score_keep]
scores_per_image = scores_per_image[score_keep]
classes_per_image = classes_per_image[score_keep]
filter_classes_mask = classes_per_image < len(classes)
classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1)
box_pred_per_image = box_pred_per_image[classes_keep]
scores_per_image = scores_per_image[classes_keep]
classes_per_image = classes_per_image[classes_keep]
# NMS
keep = batched_nms(box_pred_per_image, scores_per_image, classes_per_image, nms_threshold)
box_pred_per_image = box_pred_per_image[keep]
scores_per_image = scores_per_image[keep]
classes_per_image = classes_per_image[keep]
classes_per_image = [classes[i] for i in classes_per_image]
# create an instance
result = {}
result["boxes"] = clip_boxes(box_pred_per_image, image_size)
result["scores"] = scores_per_image
result["classes"] = classes_per_image
return result
class OmDetTurboProcessor(ProcessorMixin):
r"""
Constructs an OmDet-Turbo processor which wraps a DETR image processor and an AutoTokenizer into a
single processor.
[`OmDetTurboProcessor`] offers all the functionalities of [`DetrImageProcessor`] and
[`AutoTokenizer`]. See the docstring of [`~OmDetTurboProcessor.__call__`] and [`~OmDetTurboProcessor.decode`]
for more information.
Args:
image_processor (`DetrImageProcessor`):
An instance of [`DetrImageProcessor`]. The image processor is a required input.
tokenizer (`AutoTokenizer`):
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "DetrImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
def __call__(
self,
images: ImageInput = None,
text: Union[List[str], List[List[str]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[OmDetTurboProcessorKwargs],
) -> BatchFeature:
"""
This method uses [`DetrImageProcessor.__call__`] to prepare image(s) for the model, and
[`CLIPTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstrings of the above two methods for more information.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
text (`Union[str, List[str], List[List[str]]]`):
The classes used to limit the scope of the open vocabulary detection. Expects a list of strings or a list
of list of strings. Batched classes can be of different lengths.
Examples: ["cat", "dog", "bird"], [["cat", "dog", "bird"], ["hat", "person"], ["car"]]
Kwargs:
task (`Union[str, List[str], TextInput, PreTokenizedInput]`):
The grounded text used to guide open vocabulary detection. Expects a single string or a list of strings.
Examples: "Detect a cat, a dog, and a bird.",[ "Detect everything.", "Detect trees and flowers."]
When not provided, the default task is "Detect [class1], [class2], [class3]" etc.
...
"""
if images is None or text is None:
raise ValueError("You have to specify both `images` and `text`")
output_kwargs = self._merge_kwargs(
OmDetTurboProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = text.strip(" ").split(",")
if not (len(text) and isinstance(text[0], (list, tuple))):
text = [text]
task = output_kwargs["text_kwargs"].pop("task", None)
if task is None:
task = ["Detect {}.".format(", ".join(text_single)) for text_single in text]
elif not isinstance(task, (list, tuple)):
task = [task]
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
tasks_encoding = self.tokenizer(text=task, **output_kwargs["text_kwargs"])
classes = text
classes_structure = torch.tensor([len(class_single) for class_single in classes], dtype=torch.long)
classes_flattened = [class_single for class_batch in classes for class_single in class_batch]
classes_encoding = self.tokenizer(text=classes_flattened, **output_kwargs["text_kwargs"])
encoding = BatchFeature()
encoding.update({f"tasks_{key}": value for key, value in tasks_encoding.items()})
encoding.update({f"classes_{key}": value for key, value in classes_encoding.items()})
encoding.update({"classes_structure": classes_structure})
encoding.update(encoding_image_processor)
return encoding
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_grounded_object_detection(
self,
outputs,
classes: Union[List[str], List[List[str]]],
score_threshold: float = 0.3,
nms_threshold: float = 0.5,
target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
max_num_det: Optional[int] = None,
):
"""
Converts the raw output of [`OmDetTurboForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
bottom_right_x, bottom_right_y) format and get the associated text class.
Args:
outputs ([`OmDetTurboObjectDetectionOutput`]):
Raw outputs of the model.
classes (Union[List[str], List[List[str]]]): The input classes names.
score_threshold (float, defaults to 0.3): Only return detections with a confidence score exceeding this
threshold.
nms_threshold (float, defaults to 0.5): The threshold to use for box non-maximum suppression. Value in [0, 1].
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to None):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
max_num_det (int, *optional*, defaults to None): The maximum number of detections to return.
Returns:
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image
in the batch as predicted by the model.
"""
if isinstance(classes[0], str):
classes = [classes]
boxes_logits = outputs.decoder_coord_logits
scores_logits = outputs.decoder_class_logits
# Inputs consistency check
if target_sizes is None:
height = (
self.image_processor.size["height"]
if "height" in self.image_processor.size
else self.image_processor.size["shortest_edge"]
)
width = (
self.image_processor.size["width"]
if "width" in self.image_processor.size
else self.image_processor.size["longest_edge"]
)
target_sizes = ((height, width),) * len(boxes_logits)
elif len(target_sizes[0]) != 2:
raise ValueError(
"Each element of target_sizes must contain the size (height, width) of each image of the batch"
)
if len(target_sizes) != len(boxes_logits):
raise ValueError("Make sure that you pass in as many target sizes as output sequences")
if len(classes) != len(boxes_logits):
raise ValueError("Make sure that you pass in as many class groups as output sequences")
# Convert target_sizes to list for easier handling
if isinstance(target_sizes, torch.Tensor):
target_sizes = target_sizes.tolist()
scores, predicted_classes = compute_score(scores_logits)
num_classes = scores_logits.shape[2]
results = []
for scores_img, box_per_img, image_size, class_names in zip(scores, boxes_logits, target_sizes, classes):
results.append(
_post_process_boxes_for_image(
box_per_img,
scores_img,
predicted_classes,
class_names,
image_size,
num_classes,
score_threshold=score_threshold,
nms_threshold=nms_threshold,
max_num_det=max_num_det,
)
)
return results
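A minimal end-to-end sketch of how this post-processing is intended to be used, mirroring the slow integration tests added below (checkpoint name and image URL are taken from those tests):
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
classes = ["cat", "remote"]

inputs = processor(images=image, text=classes, task="Detect cat, remote.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects (height, width); PIL's Image.size is (width, height), hence the reversal
results = processor.post_process_grounded_object_detection(
    outputs, classes=[classes], target_sizes=[image.size[::-1]]
)[0]
for score, detected_class, box in zip(results["scores"], results["classes"], results["boxes"]):
    print(f"{detected_class}: {score.item():.2f} at {box.tolist()}")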
@@ -6587,6 +6587,20 @@ class OlmoePreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class OmDetTurboForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class OmDetTurboPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class OneFormerForUniversalSegmentation(metaclass=DummyObject):
_backends = ["torch"]
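For context, these dummies follow the usual transformers pattern: the class names stay importable without PyTorch, and instantiation fails early instead. A rough sketch of what a user would observe in a torch-free environment, assuming the standard `requires_backends` behaviour:
from transformers import is_torch_available

if not is_torch_available():
    # without the torch backend the public name resolves to the dummy class above
    from transformers import OmDetTurboForObjectDetection

    try:
        OmDetTurboForObjectDetection()
    except ImportError as err:
        print(err)  # the message points at the missing "torch" backend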
@@ -0,0 +1,904 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch OmDet-Turbo model."""
import copy
import unittest
from io import BytesIO
import requests
from transformers import OmDetTurboConfig, is_torch_available, is_vision_available
from transformers.feature_extraction_utils import BatchFeature
from transformers.file_utils import cached_property
from transformers.testing_utils import (
require_timm,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
import torch.nn.functional as F
from transformers import OmDetTurboForObjectDetection
if is_vision_available():
from PIL import Image
from transformers import AutoProcessor
class OmDetTurboModelTester:
def __init__(
self,
parent,
batch_size=6,
is_training=False,
num_channels=3,
max_text_len=7,
num_classes=3,
use_timm_backbone=False,
backbone=None,
apply_layernorm_after_vision_backbone=False,
image_size=224,
text_projection_in_dim=16,
text_projection_out_dim=16,
class_embed_dim=16,
hidden_size=8,
num_hidden_layers=2,
num_attention_heads=2,
num_queries=20,
encoder_in_channels=(16, 32, 64),
encoder_dim_feedforward=32,
num_projection_layers=1,
decoder_n_points=4,
num_feature_levels=3,
):
super().__init__()
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.num_channels = num_channels
self.max_text_len = max_text_len
self.num_classes = num_classes
self.use_timm_backbone = use_timm_backbone
self.backbone = backbone
self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone
self.image_size = image_size
self.text_projection_in_dim = text_projection_in_dim
self.text_projection_out_dim = text_projection_out_dim
self.class_embed_dim = class_embed_dim
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_queries = num_queries
self.encoder_in_channels = encoder_in_channels
self.encoder_dim_feedforward = encoder_dim_feedforward
self.num_projection_layers = num_projection_layers
self.decoder_n_points = decoder_n_points
self.num_feature_levels = num_feature_levels
self.encoder_seq_length_vision = self.image_size // 32
self.decoder_seq_length = self.num_queries
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
input_ids_tasks = ids_tensor([self.batch_size, self.max_text_len], self.num_classes)
input_ids_tasks = input_ids_tasks.to(torch_device)
input_ids_classes = torch.cat(
[ids_tensor([self.num_classes, self.max_text_len], self.num_classes) for _ in range(self.batch_size)]
)
input_ids_classes = input_ids_classes.to(torch_device)
attention_mask_tasks = torch.ones_like(input_ids_tasks, device=torch_device)
attention_mask_classes = torch.ones_like(input_ids_classes, device=torch_device)
classes_structure = torch.ones(self.batch_size, dtype=torch.long, device=torch_device) * self.num_classes
encoding = BatchFeature()
encoding.update(
{
"pixel_values": pixel_values,
"classes_input_ids": input_ids_classes,
"classes_attention_mask": attention_mask_classes,
"tasks_input_ids": input_ids_tasks,
"tasks_attention_mask": attention_mask_tasks,
"classes_structure": classes_structure,
}
)
config = self.get_config()
return config, encoding
def get_config(self):
text_backbone = {
"hidden_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"intermediate_size": 16,
"max_position_embeddings": 8,
"model_type": "clip_text_model",
}
backbone_config = {
"embed_dim": self.hidden_size,
"depths": (1, 1, 1, 1),
"num_heads": (1, 1, 1, 1),
"window_size": 7,
"image_size": self.image_size,
"out_indices": (2, 3, 4),
"model_type": "swin",
}
return OmDetTurboConfig(
text_config=text_backbone,
backbone_config=backbone_config,
use_timm_backbone=self.use_timm_backbone,
backbone=self.backbone,
apply_layernorm_after_vision_backbone=self.apply_layernorm_after_vision_backbone,
decoder_num_layers=self.num_hidden_layers,
image_size=self.image_size,
encoder_in_channels=self.encoder_in_channels,
num_queries=self.num_queries,
encoder_layers=self.num_hidden_layers,
encoder_projection_indices=[2] * self.num_projection_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_num_heads=self.num_attention_heads,
decoder_num_points=self.decoder_n_points,
num_feature_levels=self.num_feature_levels,
encoder_dim_feedforward=self.encoder_dim_feedforward,
task_encoder_hidden_dim=self.encoder_dim_feedforward,
decoder_dim_feedforward=self.encoder_dim_feedforward,
class_embed_dim=self.class_embed_dim,
text_projection_in_dim=self.text_projection_in_dim,
text_projection_out_dim=self.text_projection_out_dim,
encoder_hidden_dim=self.hidden_size,
decoder_hidden_dim=self.hidden_size,
vision_features_channels=[self.hidden_size, self.hidden_size, self.hidden_size],
)
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def create_and_check_object_detection_head_model(self, config, inputs_dict):
model = OmDetTurboForObjectDetection(config=config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
self.parent.assertEqual(result.decoder_coord_logits.shape, (self.batch_size, self.num_queries, 4))
self.parent.assertEqual(
result.decoder_class_logits.shape, (self.batch_size, self.num_queries, self.num_classes)
)
@require_torch
class OmDetTurboModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (OmDetTurboForObjectDetection,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
pipeline_model_mapping = (
{"zero-shot-object-detection": OmDetTurboForObjectDetection} if is_torch_available() else {}
)
# special case for head models
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
return inputs_dict
def setUp(self):
self.model_tester = OmDetTurboModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=OmDetTurboConfig,
has_text_modality=False,
common_properties=["d_model", "encoder_attention_heads", "decoder_num_heads"],
)
def test_config(self):
self.config_tester.run_common_tests()
def test_object_detection_head_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_object_detection_head_model(config, inputs_dict)
@unittest.skip(
reason="Unsupported because the classes input ids are flattened by the processor: https://github.com/huggingface/transformers/issues/33669"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="OmDet-Turbo does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
def test_torchscript_output_attentions(self):
pass
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
def test_torchscript_output_hidden_states(self):
pass
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
def test_torchscript_simple(self):
pass
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
def test_torchscript_output_hidden_state(self):
pass
def test_resize_tokens_embeddings(self):
# rewrite as OmDet-Turbo does not have "input_ids" and "decoder_input_ids"
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_embed_pre_resize = model.get_input_embeddings()
type_model_embed_pre_resize = type(model_embed_pre_resize)
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
# Retrieve the embeddings and clone them
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check to make sure the type of embeddings returned post resizing is same as type of input
type_model_embed_post_resize = type(model_embed)
self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary
inputs_dict["tasks_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
# make sure that classes_input_ids are resized as well
if "classes_input_ids" in inputs_dict:
inputs_dict["classes_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
self.assertTrue(new_model_vocab_size, model.vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
target_dimension = 128
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
self.assertTrue(model_embed.weight.shape[0], target_dimension)
with self.assertRaisesRegex(
ValueError,
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
):
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
# Overwrite as `init_reference_points` is not batch dependent and contains `inf` values
def test_batching_equivalence(self):
"""
Tests that the model supports batching and that the output is nearly the same for the same input in
different batch sizes.
(Why "nearly the same" not "exactly the same"? Batching uses different matmul shapes, which often leads to
different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
"""
def get_tensor_equivalence_function(batched_input):
# models operating on continuous spaces have higher abs difference than LMs
# instead, we can rely on cos distance for image/speech models, similar to `diffusers`
if "input_ids" not in batched_input:
return lambda tensor1, tensor2: (
1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)
)
return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2))
def recursive_check(batched_object, single_row_object, model_name, key):
if isinstance(batched_object, (list, tuple)):
for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
recursive_check(batched_object_value, single_row_object_value, model_name, key)
elif isinstance(batched_object, dict):
for batched_object_value, single_row_object_value in zip(
batched_object.values(), single_row_object.values()
):
recursive_check(batched_object_value, single_row_object_value, model_name, key)
# do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
elif batched_object is None or not isinstance(batched_object, torch.Tensor):
return
elif batched_object.dim() == 0:
return
elif key != "init_reference_points":
# `init_reference_points` is skipped: it is not batch dependent and contains `inf` values
# indexing the first element does not always work
# e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
slice_ids = [slice(0, index) for index in single_row_object.shape]
batched_row = batched_object[slice_ids]
self.assertFalse(
torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
)
self.assertFalse(
torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
)
self.assertFalse(
torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
)
self.assertFalse(
torch.isinf(single_row_object).any(),
f"Single row output has `inf` in {model_name} for key={key}",
)
self.assertTrue(
(equivalence(batched_row, single_row_object)) <= 1e-03,
msg=(
f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
f"Difference={equivalence(batched_row, single_row_object)}."
),
)
config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
equivalence = get_tensor_equivalence_function(batched_input)
for model_class in self.all_model_classes:
config.output_hidden_states = True
model_name = model_class.__name__
if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
model = model_class(config).to(torch_device).eval()
batch_size = self.model_tester.batch_size
single_row_input = {}
for key, value in batched_input_prepared.items():
single_batch_shape = value.shape[0] // batch_size
single_row_input[key] = value[:single_batch_shape]
with torch.no_grad():
model_batched_output = model(**batched_input_prepared)
model_row_output = model(**single_row_input)
if isinstance(model_batched_output, torch.Tensor):
model_batched_output = {"model_output": model_batched_output}
model_row_output = {"model_output": model_row_output}
for key in model_batched_output:
# DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan`
if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key:
model_batched_output[key] = model_batched_output[key][1:]
model_row_output[key] = model_row_output[key][1:]
if key in ("decoder_class_logits", "decoder_classes", "encoder_class_logits"):
# check if all elements are close to 0, if so skip the test as the test struggles with comparing
# tensors with all elements close to 0
if torch.allclose(
model_batched_output[key], torch.zeros_like(model_batched_output[key]), atol=1e-6
) and torch.allclose(model_row_output[key], torch.zeros_like(model_row_output[key]), atol=1e-6):
continue
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
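To make the 1e-3 tolerance above concrete, here is a small self-contained sketch of the cosine-distance criterion the test applies to models without `input_ids` (illustrative tensors only, not real model outputs):
import torch
import torch.nn.functional as F

def cosine_distance(tensor1, tensor2):
    # 1 - cosine similarity of the flattened tensors; values <= 1e-3 count as "nearly the same"
    return 1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)

reference = torch.randn(2, 20, 4)
perturbed = reference + 1e-5 * torch.randn_like(reference)  # e.g. noise from different matmul shapes
print(cosine_distance(reference, perturbed) <= 1e-3)  # expected: tensor(True)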
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions[-1]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions[-1]
self.assertEqual(
len(attentions), self.model_tester.num_hidden_layers * self.model_tester.num_projection_layers
)
# Rest of the shape seems to depend on backbone output shapes and image size
self.assertListEqual(
list(attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.encoder_seq_length_vision**2,
self.model_tester.encoder_seq_length_vision**2,
],
)
# decoder attentions
decoder_attentions = outputs.decoder_attentions[0]
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.num_queries + self.model_tester.max_text_len,
self.model_tester.num_queries + self.model_tester.max_text_len,
],
)
# cross attentions
cross_attentions = outputs.decoder_attentions[-1]
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.num_feature_levels,
self.model_tester.decoder_n_points,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self_attentions = outputs.encoder_attentions[-1]
self.assertEqual(
len(self_attentions), self.model_tester.num_hidden_layers * self.model_tester.num_projection_layers
)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
self.model_tester.encoder_seq_length_vision**2,
self.model_tester.encoder_seq_length_vision**2,
],
)
# overwrite since encoder_hidden_states are 3-dim and not 2-dim
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_projection_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = self.model_tester.encoder_seq_length_vision
self.assertListEqual(list(hidden_states[0].shape[-3:]), [self.model_tester.hidden_size, seq_len, seq_len])
hidden_states = outputs.decoder_hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# removed retain_grad and grad on decoder_hidden_states, as queries don't require grad
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
output = outputs[0]
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_attentions = outputs.encoder_attentions[0][0]
encoder_hidden_states.retain_grad()
encoder_attentions.retain_grad()
cross_attentions = outputs.decoder_attentions[-1][0]
cross_attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
self.assertIsNotNone(cross_attentions.grad)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if (
"embeddings" in name
or ".fc" in name
or "decoder.channel_projection_layers" in name
or "query_position_head" in name
or "decoder.encoder_vision_features" in name
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} seems not properly initialized",
)
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return image
def prepare_text():
classes = ["cat", "remote"]
task = "Detect {}.".format(", ".join(classes))
return classes, task
def prepare_img_batched():
url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
return [Image.open(BytesIO(requests.get(url).content)).convert("RGB") for url in [url1, url2, url3]]
def prepare_text_batched():
classes1 = ["cat", "remote"]
classes2 = ["boat"]
classes3 = ["statue", "trees", "torch"]
task1 = "Detect {}.".format(", ".join(classes1))
task2 = "Detect all the boat in the image."
task3 = "Focus on the foreground, detect statue, torch and trees."
return [classes1, classes2, classes3], [task1, task2, task3]
@require_timm
@require_vision
@slow
class OmDetTurboModelIntegrationTests(unittest.TestCase):
@cached_property
def default_processor(self):
return AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") if is_vision_available() else None
def test_inference_object_detection_head(self):
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
processor = self.default_processor
image = prepare_img()
classes, task = prepare_text()
encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**encoding)
expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4))
expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2))
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to(
torch_device
)
expected_coord_logits = torch.tensor(
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
# verify grounded postprocessing
results = processor.post_process_grounded_object_detection(
outputs, classes=[classes], target_sizes=[image.size[::-1]]
)[0]
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
self.assertEqual(len(results["scores"]), 4)
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
expected_classes = ["remote", "cat", "remote", "cat"]
self.assertListEqual(results["classes"], expected_classes)
def test_inference_object_detection_head_fp16(self):
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(
torch_device, dtype=torch.float16
)
processor = self.default_processor
image = prepare_img()
classes, task = prepare_text()
encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(
torch_device, dtype=torch.float16
)
with torch.no_grad():
outputs = model(**encoding)
expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4))
expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2))
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to(
torch_device, dtype=torch.float16
)
expected_coord_logits = torch.tensor(
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
).to(torch_device, dtype=torch.float16)
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
# verify grounded postprocessing
results = processor.post_process_grounded_object_detection(
outputs, classes=[classes], target_sizes=[image.size[::-1]]
)[0]
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device, dtype=torch.float16)
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(
torch_device, dtype=torch.float16
)
self.assertEqual(len(results["scores"]), 4)
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-1))
expected_classes = ["remote", "cat", "remote", "cat"]
self.assertListEqual(results["classes"], expected_classes)
def test_inference_object_detection_head_no_task(self):
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
processor = self.default_processor
image = prepare_img()
classes, _ = prepare_text()
encoding = processor(images=image, text=classes, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**encoding)
expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4))
expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2))
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to(
torch_device
)
expected_coord_logits = torch.tensor(
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
# verify grounded postprocessing
results = processor.post_process_grounded_object_detection(
outputs, classes=[classes], target_sizes=[image.size[::-1]]
)[0]
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
self.assertEqual(len(results["scores"]), 4)
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
expected_classes = ["remote", "cat", "remote", "cat"]
self.assertListEqual(results["classes"], expected_classes)
def test_inference_object_detection_head_batched(self):
torch_device = "cpu"
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
processor = self.default_processor
images_batched = prepare_img_batched()
classes_batched, tasks_batched = prepare_text_batched()
encoding = processor(images=images_batched, text=classes_batched, task=tasks_batched, return_tensors="pt").to(
torch_device
)
with torch.no_grad():
outputs = model(**encoding)
expected_shape_coord_logits = torch.Size((len(images_batched), model.config.num_queries, 4))
expected_shape_class_logits = torch.Size((len(images_batched), model.config.num_queries, 3))
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
expected_class_logits = torch.tensor(
[[[0.9427, -2.5958, -7.7601]], [[-2.3408, -9.3516, -9.3516]], [[1.0740, -2.3315, -1.1885]]]
).to(torch_device)
expected_coord_logits = torch.tensor(
[[[0.2550, 0.5501, 0.4738]], [[0.2535, 0.6006, 0.0353]], [[0.3742, 0.3337, 0.0666]]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:, :1, :3], expected_class_logits, atol=1e-1))
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:, :1, :3], expected_coord_logits, atol=1e-3))
# verify grounded postprocessing
results = processor.post_process_grounded_object_detection(
outputs,
classes=classes_batched,
target_sizes=[image.size[::-1] for image in images_batched],
score_threshold=0.2,
)
expected_scores = torch.tensor([0.7675, 0.3016, 0.7454]).to(torch_device)
expected_slice_boxes = torch.tensor(
[
[39.8870, 70.3522, 176.7424, 118.0354],
[146.5446, 219.7132, 209.6983, 251.0456],
[545.3470, 209.9055, 651.9860, 502.1882],
]
).to(torch_device)
self.assertListEqual([len(result["scores"]) for result in results], [4, 4, 6])
self.assertTrue(
torch.allclose(torch.stack([result["scores"][0] for result in results]), expected_scores, atol=1e-2)
)
self.assertTrue(
torch.allclose(torch.stack([result["boxes"][0, :] for result in results]), expected_slice_boxes, atol=1e-2)
)
expected_classes = [
["remote", "cat", "remote", "cat"],
["boat", "boat", "boat", "boat"],
["statue", "trees", "trees", "torch", "statue", "statue"],
]
self.assertListEqual([result["classes"] for result in results], expected_classes)
@require_torch_gpu
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
processor = self.default_processor
image = prepare_img()
classes, task = prepare_text()
encoding = processor(images=image, text=classes, task=task, return_tensors="pt")
# 1. run model on CPU
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
with torch.no_grad():
cpu_outputs = model(**encoding)
# 2. run model on GPU
model.to("cuda")
encoding = encoding.to("cuda")
with torch.no_grad():
gpu_outputs = model(**encoding)
# 3. assert equivalence
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]])
expected_coord_logits = torch.tensor(
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
)
self.assertTrue(torch.allclose(cpu_outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
self.assertTrue(torch.allclose(cpu_outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
# verify grounded postprocessing
results_cpu = processor.post_process_grounded_object_detection(
cpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
)[0]
result_gpu = processor.post_process_grounded_object_detection(
gpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
)[0]
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-2))
self.assertTrue(torch.allclose(results_cpu["boxes"][0, :], result_gpu["boxes"][0, :].cpu(), atol=1e-2))
@@ -0,0 +1,363 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
import pytest
from transformers import AutoProcessor, CLIPTokenizerFast, OmDetTurboProcessor
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
IMAGE_MEAN = [123.675, 116.28, 103.53]
IMAGE_STD = [58.395, 57.12, 57.375]
if is_torch_available():
import torch
from transformers.models.omdet_turbo.modeling_omdet_turbo import OmDetTurboObjectDetectionOutput
if is_vision_available():
from PIL import Image
from transformers import DetrImageProcessor
@require_torch
@require_vision
class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = OmDetTurboProcessor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = DetrImageProcessor()
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
processor = OmDetTurboProcessor(image_processor, tokenizer)
processor.save_pretrained(self.tmpdirname)
self.input_keys = [
"tasks_input_ids",
"tasks_attention_mask",
"classes_input_ids",
"classes_attention_mask",
"classes_structure",
"pixel_values",
"pixel_mask",
]
self.batch_size = 5
self.num_queries = 5
self.embed_dim = 3
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of random PIL images."""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
return image_inputs
def get_fake_omdet_turbo_output(self):
torch.manual_seed(42)
return OmDetTurboObjectDetectionOutput(
decoder_coord_logits=torch.rand(self.batch_size, self.num_queries, 4),
decoder_class_logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
)
def get_fake_omdet_turbo_classes(self):
return [[f"class{i}_{j}" for i in range(self.num_queries)] for j in range(self.batch_size)]
def test_post_process_grounded_object_detection(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
omdet_turbo_output = self.get_fake_omdet_turbo_output()
omdet_turbo_classes = self.get_fake_omdet_turbo_classes()
post_processed = processor.post_process_grounded_object_detection(
omdet_turbo_output, omdet_turbo_classes, target_sizes=[(400, 30) for _ in range(self.batch_size)]
)
self.assertEqual(len(post_processed), self.batch_size)
self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "classes"])
self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
expected_scores = torch.tensor([0.7310, 0.6579, 0.6513, 0.6444, 0.6252])
self.assertTrue(torch.allclose(post_processed[0]["scores"], expected_scores, atol=1e-4))
expected_box_slice = torch.tensor([14.9657, 141.2052, 30.0000, 312.9670])
self.assertTrue(torch.allclose(post_processed[0]["boxes"][0], expected_box_slice, atol=1e-4))
def test_save_load_pretrained_additional_features(self):
processor = OmDetTurboProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
processor = OmDetTurboProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, DetrImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor).image_processor
image_input = self.prepare_image_inputs()
input_image_proc = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")
for key in input_image_proc.keys():
self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor).tokenizer
input_str = "lower newer"
encoded_processor = processor(text=input_str, padding="max_length", truncation=True, max_length=77)
encoded_tok = tokenizer(input_str, padding="max_length", truncation=True, max_length=77)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_tasks = "task"
input_classes = ["class1", "class2"]
image_input = self.prepare_image_inputs()
input_processor = processor(images=image_input, text=input_classes, task=input_tasks, return_tensors="pt")
for key in self.input_keys:
assert torch.is_tensor(input_processor[key])
# test if it raises when no input is passed
with pytest.raises(ValueError):
processor()
def test_tokenizer_decode(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.batch_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_tasks = "task"
input_classes = ["class1", "class2"]
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, text=input_classes, task=input_tasks, return_tensors="pt")
self.assertListEqual(list(inputs.keys()), self.input_keys)
@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, text=[input_str], task=input_str, return_tensors="pt")
self.assertEqual(len(inputs["tasks_input_ids"][0]), 117)
self.assertEqual(len(inputs["classes_input_ids"][0]), 117)
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, text=[input_str], task=input_str, return_tensors="pt", max_length=112)
self.assertEqual(len(inputs["tasks_input_ids"][0]), 112)
self.assertEqual(len(inputs["classes_input_ids"][0]), 112)
@require_torch
@require_vision
def test_unstructured_kwargs(self):
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
images=image_input,
text=[input_str],
task=input_str,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
images=image_input,
text=[input_str],
task=input_str,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["tasks_input_ids"][0]), 6)
self.assertEqual(len(inputs["classes_input_ids"][0]), 6)
@require_torch
@require_vision
def test_structured_kwargs_nested(self):
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76, "task": input_str},
}
inputs = processor(images=image_input, text=[input_str], **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76, "task": input_str},
}
inputs = processor(images=image_input, text=[input_str], **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
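As a quick reference for the component setup used throughout these tests, a minimal sketch of assembling the processor from its two parts and round-tripping it through save/load (mirroring `setUp` and `test_save_load_pretrained_additional_features` above):
import tempfile

from transformers import CLIPTokenizerFast, DetrImageProcessor, OmDetTurboProcessor

tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
image_processor = DetrImageProcessor()
processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer)

with tempfile.TemporaryDirectory() as tmpdir:
    processor.save_pretrained(tmpdir)
    reloaded = OmDetTurboProcessor.from_pretrained(tmpdir)
    assert isinstance(reloaded.tokenizer, CLIPTokenizerFast)
    assert isinstance(reloaded.image_processor, DetrImageProcessor)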
@@ -173,7 +173,13 @@ MODEL_NAMES_WITH_SAME_CONFIG = {
"XLS-R": "Wav2Vec2",
"XLSR-Wav2Vec2": "Wav2Vec2",
}
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel", "Qwen2AudioEncoder"]
MODEL_NAMES_TO_IGNORE = [
"ChineseCLIPVisionModel",
"CLIPTextModel",
"CLIPVisionModel",
"Qwen2AudioEncoder",
"SiglipVisionModel",
]
def get_model_table_from_auto_modules() -> str: