mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add OmDet-Turbo (#31843)
* Add template with add-new-model-like * Add rough OmDetTurboEncoder and OmDetTurboDecoder * Add working OmDetTurbo convert to hf * Change OmDetTurbo encoder to RT-DETR encoder * Add swin timm backbone as default, add always partition fix for swin timm * Add labels and tasks caching * Fix make fix-copies * Format omdet_turbo * fix Tokenizer tests * Fix style and quality * Reformat omdet_turbo * Fix quality, style, copies * Standardize processor kwargs * Fix style * Add output_hidden_states and ouput_attentions * Add personalize multi-head attention, improve docstrings * Add integrated test and fix copy, style, quality * Fix unprotected import * Cleanup comments and fix unprotected imports * Add fix different prompts in batch (key_padding_mask) * Add key_padding_mask to custom multi-head attention module * Replace attention_mask by key_padding_mask * Remove OmDetTurboModel and refactor * Refactor processing of classes and abstract use of timm backbone * Add testing, fix output attentions and hidden states, add cache for anchors generation * Fix copies, style, quality * Add documentation, conver key_padding_mask to attention_mask * revert changes to backbone_utils * Fic docstrings rst * Fix unused argument in config * Fix image link documentation * Reorder config and cleanup * Add tokenizer_init_kwargs in merge_kwargs of the processor * Change AutoTokenizer to CLIPTokenizer in convert * Fix init_weights * Add ProcessorMixin tests, Fix convert while waiting on uniform kwargs * change processor kwargs and make task input optional * Fix omdet docs * Remove unnecessary tests for processor kwargs * Replace nested BatchEncoding output of the processor by a flattened BatchFeature * Make modifications from Pavel review * Add changes Amy review * Remove unused param * Remove normalize_before param, Modify processor call docstring * Remove redundant decoder class, add gradient checkpointing for decoder * Remove commented out code * Fix inference in fp16 and add fp16 integrated test * update omdet md doc * Add OmdetTurboModel * fix caching and nit * add OmDetTurboModel to tests * nit change repeated key test * Improve inference speed in eager mode * fix copies * Fix nit * remove OmdetTurboModel * [run-slow] omdet_turbo * [run-slow] omdet_turbo * skip dataparallel test * [run-slow] omdet_turbo * update weights to new path * remove unnecessary config in class --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-91-248.ec2.internal>
This commit is contained in:
parent
ade9e0fe41
commit
94f18cf23c
@ -862,6 +862,8 @@
|
||||
title: MGP-STR
|
||||
- local: model_doc/nougat
|
||||
title: Nougat
|
||||
- local: model_doc/omdet-turbo
|
||||
title: OmDet-Turbo
|
||||
- local: model_doc/oneformer
|
||||
title: OneFormer
|
||||
- local: model_doc/owlvit
|
||||
|
@ -237,6 +237,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [Nyströmformer](model_doc/nystromformer) | ✅ | ❌ | ❌ |
|
||||
| [OLMo](model_doc/olmo) | ✅ | ❌ | ❌ |
|
||||
| [OLMoE](model_doc/olmoe) | ✅ | ❌ | ❌ |
|
||||
| [OmDet-Turbo](model_doc/omdet-turbo) | ✅ | ❌ | ❌ |
|
||||
| [OneFormer](model_doc/oneformer) | ✅ | ❌ | ❌ |
|
||||
| [OpenAI GPT](model_doc/openai-gpt) | ✅ | ✅ | ❌ |
|
||||
| [OpenAI GPT-2](model_doc/gpt2) | ✅ | ✅ | ✅ |
|
||||
|
164
docs/source/en/model_doc/omdet-turbo.md
Normal file
164
docs/source/en/model_doc/omdet-turbo.md
Normal file
@ -0,0 +1,164 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# OmDet-Turbo
|
||||
|
||||
## Overview
|
||||
|
||||
The OmDet-Turbo model was proposed in [Real-time Transformer-based Open-Vocabulary Detection with Efficient Fusion Head](https://arxiv.org/abs/2403.06892) by Tiancheng Zhao, Peng Liu, Xuan He, Lu Zhang, Kyusong Lee. OmDet-Turbo incorporates components from RT-DETR and introduces a swift multimodal fusion module to achieve real-time open-vocabulary object detection capabilities while maintaining high accuracy. The base model achieves performance of up to 100.2 FPS and 53.4 AP on COCO zero-shot.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*End-to-end transformer-based detectors (DETRs) have shown exceptional performance in both closed-set and open-vocabulary object detection (OVD) tasks through the integration of language modalities. However, their demanding computational requirements have hindered their practical application in real-time object detection (OD) scenarios. In this paper, we scrutinize the limitations of two leading models in the OVDEval benchmark, OmDet and Grounding-DINO, and introduce OmDet-Turbo. This novel transformer-based real-time OVD model features an innovative Efficient Fusion Head (EFH) module designed to alleviate the bottlenecks observed in OmDet and Grounding-DINO. Notably, OmDet-Turbo-Base achieves a 100.2 frames per second (FPS) with TensorRT and language cache techniques applied. Notably, in zero-shot scenarios on COCO and LVIS datasets, OmDet-Turbo achieves performance levels nearly on par with current state-of-the-art supervised models. Furthermore, it establishes new state-of-the-art benchmarks on ODinW and OVDEval, boasting an AP of 30.1 and an NMS-AP of 26.86, respectively. The practicality of OmDet-Turbo in industrial applications is underscored by its exceptional performance on benchmark datasets and superior inference speed, positioning it as a compelling choice for real-time object detection tasks.*
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/omdet_turbo_architecture.jpeg" alt="drawing" width="600"/>
|
||||
|
||||
<small> OmDet-Turbo architecture overview. Taken from the <a href="https://arxiv.org/abs/2403.06892">original paper</a>. </small>
|
||||
|
||||
This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
|
||||
The original code can be found [here](https://github.com/om-ai-lab/OmDet).
|
||||
|
||||
## Usage tips
|
||||
|
||||
One unique property of OmDet-Turbo compared to other zero-shot object detection models, such as [Grounding DINO](grounding-dino), is the decoupled classes and prompt embedding structure that allows caching of text embeddings. This means that the model needs both classes and task as inputs, where classes is a list of objects we want to detect and task is the grounded text used to guide open-vocabulary detection. This approach limits the scope of the open-vocabulary detection and makes the decoding process faster.
|
||||
|
||||
[`OmDetTurboProcessor`] is used to prepare the classes, task and image triplet. The task input is optional, and when not provided, it will default to `"Detect [class1], [class2], [class3], ..."`. To process the results from the model, one can use `post_process_grounded_object_detection` from [`OmDetTurboProcessor`]. Notably, this function takes in the input classes, as unlike other zero-shot object detection models, the decoupling of classes and task embeddings means that no decoding of the predicted class embeddings is needed in the post-processing step, and the predicted classes can be matched to the inputted ones directly.
|
||||
|
||||
## Usage example
|
||||
|
||||
### Single image inference
|
||||
|
||||
Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image:
|
||||
|
||||
```python
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoProcessor, OmDetTurboForObjectDetection
|
||||
|
||||
processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-tiny")
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-tiny")
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
classes = ["cat", "remote"]
|
||||
inputs = processor(image, text=classes, return_tensors="pt")
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
# convert outputs (bounding boxes and class logits)
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
classes=classes,
|
||||
target_sizes=[image.size[::-1]],
|
||||
score_threshold=0.3,
|
||||
nms_threshold=0.3,
|
||||
)[0]
|
||||
for score, class_name, box in zip(
|
||||
results["scores"], results["classes"], results["boxes"]
|
||||
):
|
||||
box = [round(i, 1) for i in box.tolist()]
|
||||
print(
|
||||
f"Detected {class_name} with confidence "
|
||||
f"{round(score.item(), 2)} at location {box}"
|
||||
)
|
||||
```
|
||||
|
||||
### Multi image inference
|
||||
|
||||
OmDet-Turbo can perform batched multi-image inference, with support for different text prompts and classes in the same batch:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
|
||||
>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
|
||||
|
||||
>>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
|
||||
>>> classes1 = ["cat", "remote"]
|
||||
>>> task1 = "Detect {}.".format(", ".join(classes1))
|
||||
|
||||
>>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
|
||||
>>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
|
||||
>>> classes2 = ["boat"]
|
||||
>>> task2 = "Detect everything that looks like a boat."
|
||||
|
||||
>>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||
>>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
|
||||
>>> classes3 = ["statue", "trees"]
|
||||
>>> task3 = "Focus on the foreground, detect statue and trees."
|
||||
|
||||
>>> inputs = processor(
|
||||
... images=[image1, image2, image3],
|
||||
... text=[classes1, classes2, classes3],
|
||||
... task=[task1, task2, task3],
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> # convert outputs (bounding boxes and class logits)
|
||||
>>> results = processor.post_process_grounded_object_detection(
|
||||
... outputs,
|
||||
... classes=[classes1, classes2, classes3],
|
||||
... target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
|
||||
... score_threshold=0.2,
|
||||
... nms_threshold=0.3,
|
||||
... )
|
||||
|
||||
>>> for i, result in enumerate(results):
|
||||
... for score, class_name, box in zip(
|
||||
... result["scores"], result["classes"], result["boxes"]
|
||||
... ):
|
||||
... box = [round(i, 1) for i in box.tolist()]
|
||||
... print(
|
||||
... f"Detected {class_name} with confidence "
|
||||
... f"{round(score.item(), 2)} at location {box} in image {i}"
|
||||
... )
|
||||
Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
|
||||
Detected cat with confidence 0.72 at location [11.6, 54.2, 314.8, 474.0] in image 0
|
||||
Detected remote with confidence 0.56 at location [333.4, 75.8, 370.7, 187.0] in image 0
|
||||
Detected cat with confidence 0.55 at location [345.2, 24.0, 639.8, 371.7] in image 0
|
||||
Detected boat with confidence 0.32 at location [146.9, 219.8, 209.6, 250.7] in image 1
|
||||
Detected boat with confidence 0.3 at location [319.1, 223.2, 403.2, 238.4] in image 1
|
||||
Detected boat with confidence 0.27 at location [37.7, 220.3, 84.0, 235.9] in image 1
|
||||
Detected boat with confidence 0.22 at location [407.9, 207.0, 441.7, 220.2] in image 1
|
||||
Detected statue with confidence 0.73 at location [544.7, 210.2, 651.9, 502.8] in image 2
|
||||
Detected trees with confidence 0.25 at location [3.9, 584.3, 391.4, 785.6] in image 2
|
||||
Detected trees with confidence 0.25 at location [1.4, 621.2, 118.2, 787.8] in image 2
|
||||
Detected statue with confidence 0.2 at location [428.1, 205.5, 767.3, 759.5] in image 2
|
||||
|
||||
```
|
||||
|
||||
## OmDetTurboConfig
|
||||
|
||||
[[autodoc]] OmDetTurboConfig
|
||||
|
||||
## OmDetTurboProcessor
|
||||
|
||||
[[autodoc]] OmDetTurboProcessor
|
||||
- post_process_grounded_object_detection
|
||||
|
||||
## OmDetTurboForObjectDetection
|
||||
|
||||
[[autodoc]] OmDetTurboForObjectDetection
|
||||
- forward
|
@ -609,6 +609,10 @@ _import_structure = {
|
||||
"models.nystromformer": ["NystromformerConfig"],
|
||||
"models.olmo": ["OlmoConfig"],
|
||||
"models.olmoe": ["OlmoeConfig"],
|
||||
"models.omdet_turbo": [
|
||||
"OmDetTurboConfig",
|
||||
"OmDetTurboProcessor",
|
||||
],
|
||||
"models.oneformer": [
|
||||
"OneFormerConfig",
|
||||
"OneFormerProcessor",
|
||||
@ -2861,6 +2865,12 @@ else:
|
||||
"OlmoePreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.omdet_turbo"].extend(
|
||||
[
|
||||
"OmDetTurboForObjectDetection",
|
||||
"OmDetTurboPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.oneformer"].extend(
|
||||
[
|
||||
"OneFormerForUniversalSegmentation",
|
||||
@ -5407,6 +5417,10 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.olmo import OlmoConfig
|
||||
from .models.olmoe import OlmoeConfig
|
||||
from .models.omdet_turbo import (
|
||||
OmDetTurboConfig,
|
||||
OmDetTurboProcessor,
|
||||
)
|
||||
from .models.oneformer import (
|
||||
OneFormerConfig,
|
||||
OneFormerProcessor,
|
||||
@ -7383,6 +7397,10 @@ if TYPE_CHECKING:
|
||||
OlmoeModel,
|
||||
OlmoePreTrainedModel,
|
||||
)
|
||||
from .models.omdet_turbo import (
|
||||
OmDetTurboForObjectDetection,
|
||||
OmDetTurboPreTrainedModel,
|
||||
)
|
||||
from .models.oneformer import (
|
||||
OneFormerForUniversalSegmentation,
|
||||
OneFormerModel,
|
||||
|
@ -173,6 +173,7 @@ from . import (
|
||||
nystromformer,
|
||||
olmo,
|
||||
olmoe,
|
||||
omdet_turbo,
|
||||
oneformer,
|
||||
openai,
|
||||
opt,
|
||||
|
@ -60,6 +60,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
|
||||
("clap", "ClapConfig"),
|
||||
("clip", "CLIPConfig"),
|
||||
("clip_text_model", "CLIPTextConfig"),
|
||||
("clip_vision_model", "CLIPVisionConfig"),
|
||||
("clipseg", "CLIPSegConfig"),
|
||||
("clvp", "ClvpConfig"),
|
||||
@ -191,6 +192,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("nystromformer", "NystromformerConfig"),
|
||||
("olmo", "OlmoConfig"),
|
||||
("olmoe", "OlmoeConfig"),
|
||||
("omdet-turbo", "OmDetTurboConfig"),
|
||||
("oneformer", "OneFormerConfig"),
|
||||
("open-llama", "OpenLlamaConfig"),
|
||||
("openai-gpt", "OpenAIGPTConfig"),
|
||||
@ -348,6 +350,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
|
||||
("clap", "CLAP"),
|
||||
("clip", "CLIP"),
|
||||
("clip_text_model", "CLIPTextModel"),
|
||||
("clip_vision_model", "CLIPVisionModel"),
|
||||
("clipseg", "CLIPSeg"),
|
||||
("clvp", "CLVP"),
|
||||
@ -497,6 +500,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("nystromformer", "Nyströmformer"),
|
||||
("olmo", "OLMo"),
|
||||
("olmoe", "OLMoE"),
|
||||
("omdet-turbo", "OmDet-Turbo"),
|
||||
("oneformer", "OneFormer"),
|
||||
("open-llama", "OpenLlama"),
|
||||
("openai-gpt", "OpenAI GPT"),
|
||||
@ -665,6 +669,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("xclip", "x_clip"),
|
||||
("clip_vision_model", "clip"),
|
||||
("qwen2_audio_encoder", "qwen2_audio"),
|
||||
("clip_text_model", "clip"),
|
||||
("siglip_vision_model", "siglip"),
|
||||
("chinese_clip_vision_model", "chinese_clip"),
|
||||
("rt_detr_resnet", "rt_detr"),
|
||||
|
@ -60,6 +60,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
|
||||
("clap", "ClapModel"),
|
||||
("clip", "CLIPModel"),
|
||||
("clip_text_model", "CLIPTextModel"),
|
||||
("clip_vision_model", "CLIPVisionModel"),
|
||||
("clipseg", "CLIPSegModel"),
|
||||
("clvp", "ClvpModelForConditionalGeneration"),
|
||||
@ -181,6 +182,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("nystromformer", "NystromformerModel"),
|
||||
("olmo", "OlmoModel"),
|
||||
("olmoe", "OlmoeModel"),
|
||||
("omdet-turbo", "OmDetTurboForObjectDetection"),
|
||||
("oneformer", "OneFormerModel"),
|
||||
("open-llama", "OpenLlamaModel"),
|
||||
("openai-gpt", "OpenAIGPTModel"),
|
||||
@ -812,6 +814,7 @@ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Zero Shot Object Detection mapping
|
||||
("grounding-dino", "GroundingDinoForObjectDetection"),
|
||||
("omdet-turbo", "OmDetTurboForObjectDetection"),
|
||||
("owlv2", "Owlv2ForObjectDetection"),
|
||||
("owlvit", "OwlViTForObjectDetection"),
|
||||
]
|
||||
@ -1326,6 +1329,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
("albert", "AlbertModel"),
|
||||
("bert", "BertModel"),
|
||||
("big_bird", "BigBirdModel"),
|
||||
("clip_text_model", "CLIPTextModel"),
|
||||
("data2vec-text", "Data2VecTextModel"),
|
||||
("deberta", "DebertaModel"),
|
||||
("deberta-v2", "DebertaV2Model"),
|
||||
|
@ -344,6 +344,10 @@ else:
|
||||
),
|
||||
("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"omdet-turbo",
|
||||
("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None),
|
||||
),
|
||||
("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"openai-gpt",
|
||||
|
56
src/transformers/models/omdet_turbo/__init__.py
Normal file
56
src/transformers/models/omdet_turbo/__init__.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_omdet_turbo": ["OmDetTurboConfig"],
|
||||
"processing_omdet_turbo": ["OmDetTurboProcessor"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_omdet_turbo"] = [
|
||||
"OmDetTurboForObjectDetection",
|
||||
"OmDetTurboPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_omdet_turbo import (
|
||||
OmDetTurboConfig,
|
||||
)
|
||||
from .processing_omdet_turbo import OmDetTurboProcessor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_omdet_turbo import (
|
||||
OmDetTurboForObjectDetection,
|
||||
OmDetTurboPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
290
src/transformers/models/omdet_turbo/configuration_omdet_turbo.py
Normal file
290
src/transformers/models/omdet_turbo/configuration_omdet_turbo.py
Normal file
@ -0,0 +1,290 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""OmDet-Turbo model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import verify_backbone_config_arguments
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class OmDetTurboConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`OmDetTurboForObjectDetection`].
|
||||
It is used to instantiate a OmDet-Turbo model according to the specified arguments, defining the model architecture
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the OmDet-Turbo
|
||||
[omlab/omdet-turbo-swin-tiny-hf](https://huggingface.co/omlab/omdet-turbo-swin-tiny-hf) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`PretrainedConfig`, *optional*):
|
||||
The configuration of the text backbone.
|
||||
backbone_config (`PretrainedConfig`, *optional*):
|
||||
The configuration of the vision backbone.
|
||||
use_timm_backbone (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the timm for the vision backbone.
|
||||
backbone (`str`, *optional*, defaults to `"swin_tiny_patch4_window7_224"`):
|
||||
The name of the pretrained vision backbone to use. If `use_pretrained_backbone=False` a randomly initialized
|
||||
backbone with the same architecture `backbone` is used.
|
||||
backbone_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for the vision backbone.
|
||||
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a pretrained vision backbone.
|
||||
apply_layernorm_after_vision_backbone (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply layer normalization on the feature maps of the vision backbone output.
|
||||
image_size (`int`, *optional*, defaults to 640):
|
||||
The size (resolution) of each image.
|
||||
disable_custom_kernels (`bool`, *optional*, defaults to `False`):
|
||||
Whether to disable custom kernels.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon value for layer normalization.
|
||||
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon value for batch normalization.
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
text_projection_in_dim (`int`, *optional*, defaults to 512):
|
||||
The input dimension for the text projection.
|
||||
text_projection_out_dim (`int`, *optional*, defaults to 512):
|
||||
The output dimension for the text projection.
|
||||
task_encoder_hidden_dim (`int`, *optional*, defaults to 1024):
|
||||
The feedforward dimension for the task encoder.
|
||||
class_embed_dim (`int`, *optional*, defaults to 512):
|
||||
The dimension of the classes embeddings.
|
||||
class_distance_type (`str`, *optional*, defaults to `"cosine"`):
|
||||
The type of of distance to compare predicted classes to projected classes embeddings.
|
||||
Can be `"cosine"` or `"dot"`.
|
||||
num_queries (`int`, *optional*, defaults to 900):
|
||||
The number of queries.
|
||||
csp_activation (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function of the Cross Stage Partial (CSP) networks of the encoder.
|
||||
conv_norm_activation (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function of the ConvNormLayer layers of the encoder.
|
||||
encoder_feedforward_activation (`str`, *optional*, defaults to `"relu"`):
|
||||
The activation function for the feedforward network of the encoder.
|
||||
encoder_feedforward_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout rate following the activation of the encoder feedforward network.
|
||||
encoder_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout rate of the encoder multi-head attention module.
|
||||
hidden_expansion (`int`, *optional*, defaults to 1):
|
||||
The hidden expansion of the CSP networks in the encoder.
|
||||
vision_features_channels (`tuple(int)`, *optional*, defaults to `[256, 256, 256]`):
|
||||
The projected vision features channels used as inputs for the decoder.
|
||||
encoder_hidden_dim (`int`, *optional*, defaults to 256):
|
||||
The hidden dimension of the encoder.
|
||||
encoder_in_channels (`List(int)`, *optional*, defaults to `[192, 384, 768]`):
|
||||
The input channels for the encoder.
|
||||
encoder_projection_indices (`List(int)`, *optional*, defaults to `[2]`):
|
||||
The indices of the input features projected by each layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 8):
|
||||
The number of attention heads for the encoder.
|
||||
encoder_dim_feedforward (`int`, *optional*, defaults to 2048):
|
||||
The feedforward dimension for the encoder.
|
||||
encoder_layers (`int`, *optional*, defaults to 1):
|
||||
The number of layers in the encoder.
|
||||
positional_encoding_temperature (`int`, *optional*, defaults to 10000):
|
||||
The positional encoding temperature in the encoder.
|
||||
num_feature_levels (`int`, *optional*, defaults to 3):
|
||||
The number of feature levels for the multi-scale deformable attention module of the decoder.
|
||||
decoder_hidden_dim (`int`, *optional*, defaults to 256):
|
||||
The hidden dimension of the decoder.
|
||||
decoder_num_heads (`int`, *optional*, defaults to 8):
|
||||
The number of heads for the decoder.
|
||||
decoder_num_layers (`int`, *optional*, defaults to 6):
|
||||
The number of layers for the decoder.
|
||||
decoder_activation (`str`, *optional*, defaults to `"relu"`):
|
||||
The activation function for the decoder.
|
||||
decoder_dim_feedforward (`int`, *optional*, defaults to 2048):
|
||||
The feedforward dimension for the decoder.
|
||||
decoder_num_points (`int`, *optional*, defaults to 4):
|
||||
The number of points sampled in the decoder multi-scale deformable attention module.
|
||||
decoder_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout rate for the decoder.
|
||||
eval_size (`Tuple[int, int]`, *optional*):
|
||||
Height and width used to computes the effective height and width of the position embeddings after taking
|
||||
into account the stride (see RTDetr).
|
||||
learn_initial_query (`bool`, *optional*, defaults to `False`):
|
||||
Whether to learn the initial query.
|
||||
cache_size (`int`, *optional*, defaults to 100):
|
||||
The cache size for the classes and prompts caches.
|
||||
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
||||
Whether the model is used as an encoder-decoder model or not.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional parameters from the architecture. The values in kwargs will be saved as part of the configuration
|
||||
and can be used to control the model outputs.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import OmDetTurboConfig, OmDetTurboForObjectDetection
|
||||
|
||||
>>> # Initializing a OmDet-Turbo omlab/omdet-turbo-tiny style configuration
|
||||
>>> configuration = OmDetTurboConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the omlab/omdet-turbo-tiny style configuration
|
||||
>>> model = OmDetTurboForObjectDetection(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "omdet-turbo"
|
||||
attribute_map = {
|
||||
"encoder_hidden_dim": "d_model",
|
||||
"num_attention_heads": "encoder_attention_heads",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
backbone_config=None,
|
||||
use_timm_backbone=True,
|
||||
backbone="swin_tiny_patch4_window7_224",
|
||||
backbone_kwargs=None,
|
||||
use_pretrained_backbone=False,
|
||||
apply_layernorm_after_vision_backbone=True,
|
||||
image_size=640,
|
||||
disable_custom_kernels=False,
|
||||
layer_norm_eps=1e-5,
|
||||
batch_norm_eps=1e-5,
|
||||
init_std=0.02,
|
||||
text_projection_in_dim=512,
|
||||
text_projection_out_dim=512,
|
||||
task_encoder_hidden_dim=1024,
|
||||
class_embed_dim=512,
|
||||
class_distance_type="cosine",
|
||||
num_queries=900,
|
||||
csp_activation="silu",
|
||||
conv_norm_activation="gelu",
|
||||
encoder_feedforward_activation="relu",
|
||||
encoder_feedforward_dropout=0.0,
|
||||
encoder_dropout=0.0,
|
||||
hidden_expansion=1,
|
||||
vision_features_channels=[256, 256, 256],
|
||||
encoder_hidden_dim=256,
|
||||
encoder_in_channels=[192, 384, 768],
|
||||
encoder_projection_indices=[2],
|
||||
encoder_attention_heads=8,
|
||||
encoder_dim_feedforward=2048,
|
||||
encoder_layers=1,
|
||||
positional_encoding_temperature=10000,
|
||||
num_feature_levels=3,
|
||||
decoder_hidden_dim=256,
|
||||
decoder_num_heads=8,
|
||||
decoder_num_layers=6,
|
||||
decoder_activation="relu",
|
||||
decoder_dim_feedforward=2048,
|
||||
decoder_num_points=4,
|
||||
decoder_dropout=0.0,
|
||||
eval_size=None,
|
||||
learn_initial_query=False,
|
||||
cache_size=100,
|
||||
is_encoder_decoder=True,
|
||||
**kwargs,
|
||||
):
|
||||
if use_timm_backbone:
|
||||
if backbone_config is None:
|
||||
backbone_kwargs = {
|
||||
"out_indices": [1, 2, 3],
|
||||
"img_size": image_size,
|
||||
"always_partition": True,
|
||||
}
|
||||
elif backbone_config is None:
|
||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `swin` vision config.")
|
||||
backbone_config = CONFIG_MAPPING["swin"](
|
||||
window_size=7,
|
||||
image_size=image_size,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
out_indices=[2, 3, 4],
|
||||
)
|
||||
elif isinstance(backbone_config, dict):
|
||||
backbone_model_type = backbone_config.get("model_type")
|
||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||
backbone_config = config_class.from_dict(backbone_config)
|
||||
|
||||
verify_backbone_config_arguments(
|
||||
use_timm_backbone=use_timm_backbone,
|
||||
use_pretrained_backbone=use_pretrained_backbone,
|
||||
backbone=backbone,
|
||||
backbone_config=backbone_config,
|
||||
backbone_kwargs=backbone_kwargs,
|
||||
)
|
||||
|
||||
if text_config is None:
|
||||
logger.info(
|
||||
"`text_config` is `None`. Initializing the config with the default `clip_text_model` text config."
|
||||
)
|
||||
text_config = CONFIG_MAPPING["clip_text_model"]()
|
||||
elif isinstance(text_config, dict):
|
||||
text_model_type = text_config.get("model_type")
|
||||
text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
||||
|
||||
if class_distance_type not in ["cosine", "dot"]:
|
||||
raise ValueError(
|
||||
f"Invalid `class_distance_type`. It should be either `cosine` or `dot`, but got {class_distance_type}."
|
||||
)
|
||||
|
||||
self.text_config = text_config
|
||||
self.backbone_config = backbone_config
|
||||
self.use_timm_backbone = use_timm_backbone
|
||||
self.backbone = backbone
|
||||
self.backbone_kwargs = backbone_kwargs
|
||||
self.use_pretrained_backbone = use_pretrained_backbone
|
||||
self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone
|
||||
self.image_size = image_size
|
||||
self.disable_custom_kernels = disable_custom_kernels
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.batch_norm_eps = batch_norm_eps
|
||||
self.init_std = init_std
|
||||
self.text_projection_in_dim = text_projection_in_dim
|
||||
self.text_projection_out_dim = text_projection_out_dim
|
||||
self.task_encoder_hidden_dim = task_encoder_hidden_dim
|
||||
self.class_embed_dim = class_embed_dim
|
||||
self.class_distance_type = class_distance_type
|
||||
self.num_queries = num_queries
|
||||
self.csp_activation = csp_activation
|
||||
self.conv_norm_activation = conv_norm_activation
|
||||
self.encoder_feedforward_activation = encoder_feedforward_activation
|
||||
self.encoder_feedforward_dropout = encoder_feedforward_dropout
|
||||
self.encoder_dropout = encoder_dropout
|
||||
self.hidden_expansion = hidden_expansion
|
||||
self.vision_features_channels = vision_features_channels
|
||||
self.encoder_hidden_dim = encoder_hidden_dim
|
||||
self.encoder_in_channels = encoder_in_channels
|
||||
self.encoder_projection_indices = encoder_projection_indices
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.encoder_dim_feedforward = encoder_dim_feedforward
|
||||
self.encoder_layers = encoder_layers
|
||||
self.positional_encoding_temperature = positional_encoding_temperature
|
||||
self.num_feature_levels = num_feature_levels
|
||||
self.decoder_hidden_dim = decoder_hidden_dim
|
||||
self.decoder_num_heads = decoder_num_heads
|
||||
self.decoder_num_layers = decoder_num_layers
|
||||
self.decoder_activation = decoder_activation
|
||||
self.decoder_dim_feedforward = decoder_dim_feedforward
|
||||
self.decoder_num_points = decoder_num_points
|
||||
self.decoder_dropout = decoder_dropout
|
||||
self.eval_size = eval_size
|
||||
self.learn_initial_query = learn_initial_query
|
||||
self.cache_size = cache_size
|
||||
self.is_encoder_decoder = is_encoder_decoder
|
||||
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
349
src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py
Normal file
349
src/transformers/models/omdet_turbo/convert_omdet_turbo_to_hf.py
Normal file
@ -0,0 +1,349 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert OmDet-Turbo checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/om-ai-lab/OmDet"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
CLIPTokenizer,
|
||||
DetrImageProcessor,
|
||||
OmDetTurboConfig,
|
||||
OmDetTurboForObjectDetection,
|
||||
OmDetTurboProcessor,
|
||||
)
|
||||
|
||||
|
||||
IMAGE_MEAN = [123.675, 116.28, 103.53]
|
||||
IMAGE_STD = [58.395, 57.12, 57.375]
|
||||
|
||||
|
||||
def get_omdet_turbo_config(model_name, use_timm_backbone):
|
||||
if "tiny" in model_name:
|
||||
window_size = 7
|
||||
embed_dim = 96
|
||||
depths = (2, 2, 6, 2)
|
||||
num_heads = (3, 6, 12, 24)
|
||||
image_size = 640
|
||||
else:
|
||||
raise ValueError("Model not supported, only supports tiny variant.")
|
||||
|
||||
config = OmDetTurboConfig(
|
||||
backbone_window_size=window_size,
|
||||
backbone_image_size=image_size,
|
||||
backbone_embed_dim=embed_dim,
|
||||
backbone_depths=depths,
|
||||
backbone_num_heads=num_heads,
|
||||
backbone_out_indices=(1, 2, 3),
|
||||
text_config={"model_type": "clip_text_model"},
|
||||
use_timm_backbone=use_timm_backbone,
|
||||
backbone="swin_tiny_patch4_window7_224" if use_timm_backbone else None,
|
||||
apply_layernorm_after_vision_backbone=True if use_timm_backbone else False,
|
||||
use_pretrained_backbone=False,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def create_rename_keys_vision(state_dict, config):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
########################################## VISION BACKBONE - START
|
||||
for layer_name in state_dict.keys():
|
||||
if layer_name.startswith("backbone") and not layer_name.startswith("backbone.norm"):
|
||||
if config.use_timm_backbone:
|
||||
layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone._backbone")
|
||||
layer_name_replace = layer_name_replace.replace(".layers.", ".layers_")
|
||||
if "downsample" in layer_name:
|
||||
# get layer number
|
||||
layer_num = int(layer_name.split(".")[2])
|
||||
layer_name_replace = layer_name_replace.replace(f"{layer_num}.downsample", f"{layer_num+1}.downsample")
|
||||
else:
|
||||
layer_name_replace = layer_name.replace("backbone", "vision_backbone.vision_backbone")
|
||||
layer_name_replace = layer_name_replace.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
layer_name_replace = layer_name_replace.replace("patch_embed.norm", "embeddings.norm")
|
||||
if layer_name.startswith("backbone.layers"):
|
||||
layer_name_replace = layer_name_replace.replace("norm1", "layernorm_before")
|
||||
layer_name_replace = layer_name_replace.replace("norm2", "layernorm_after")
|
||||
layer_name_replace = layer_name_replace.replace("attn.proj", "attention.output.dense")
|
||||
layer_name_replace = layer_name_replace.replace("mlp.fc1", "intermediate.dense")
|
||||
layer_name_replace = layer_name_replace.replace("mlp.fc2", "output.dense")
|
||||
layer_name_replace = layer_name_replace.replace(".layers.", ".encoder.layers.")
|
||||
layer_name_replace = layer_name_replace.replace(".attn.", ".attention.self.")
|
||||
elif layer_name.startswith("backbone.norm"):
|
||||
layer_num = int(layer_name.split("norm")[1].split(".")[0])
|
||||
if config.use_timm_backbone:
|
||||
layer_name_replace = layer_name.replace("backbone", "vision_backbone")
|
||||
layer_name_replace = layer_name_replace.replace(f"norm{layer_num}", f"layer_norms.{layer_num-1}")
|
||||
else:
|
||||
layer_name_replace = layer_name.replace(f"backbone.norm{layer_num}", f"vision_backbone.vision_backbone.hidden_states_norms.stage{layer_num+1}")
|
||||
else:
|
||||
continue
|
||||
rename_keys.append((layer_name, layer_name_replace))
|
||||
########################################## VISION BACKBONE - END
|
||||
|
||||
########################################## ENCODER - START
|
||||
for layer_name, params in state_dict.items():
|
||||
if "neck" in layer_name:
|
||||
layer_name_replace = layer_name.replace("neck", "encoder")
|
||||
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
||||
if "fpn_blocks" in layer_name or "pan_blocks" in layer_name or "lateral_convs" in layer_name or "downsample_convs" in layer_name:
|
||||
layer_name_replace = layer_name_replace.replace(".m.", ".bottlenecks.")
|
||||
layer_name_replace = layer_name_replace.replace(".cv", ".conv")
|
||||
layer_name_replace = layer_name_replace.replace(".bn", ".norm")
|
||||
if "encoder_layer" in layer_name:
|
||||
layer_name_replace = layer_name_replace.replace("encoder_layer", "encoder.0.layers.0")
|
||||
layer_name_replace = layer_name_replace.replace(".linear", ".fc")
|
||||
layer_name_replace = layer_name_replace.replace("norm1", "self_attn_layer_norm")
|
||||
layer_name_replace = layer_name_replace.replace("norm2", "final_layer_norm")
|
||||
rename_keys.append((layer_name, layer_name_replace))
|
||||
########################################## ENCODER - END
|
||||
|
||||
########################################## DECODER - START
|
||||
for layer_name, params in state_dict.items():
|
||||
if layer_name.startswith("decoder"):
|
||||
layer_name_replace = layer_name.replace("decoder.decoder.layers", "decoder.layers")
|
||||
layer_name_replace = layer_name_replace.replace("input_proj", "channel_projection_layers")
|
||||
layer_name_replace = layer_name_replace.replace("query_pos_head", "query_position_head")
|
||||
layer_name_replace = layer_name_replace.replace("enc_bbox_head", "encoder_bbox_head")
|
||||
layer_name_replace = layer_name_replace.replace("enc_output", "encoder_vision_features")
|
||||
layer_name_replace = layer_name_replace.replace("dec_score_head", "decoder_class_head")
|
||||
layer_name_replace = layer_name_replace.replace("dec_bbox_head", "decoder_bbox_head")
|
||||
layer_name_replace = layer_name_replace.replace("enc_score_head", "encoder_class_head")
|
||||
rename_keys.append((layer_name, layer_name_replace))
|
||||
########################################## DECODER - END
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def create_rename_keys_language(state_dict):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
for layer_name in state_dict.keys():
|
||||
if layer_name.startswith("language_backbone") and not layer_name.startswith("language_backbone.text_projection"):
|
||||
layer_name_replace = layer_name.replace("language_backbone", "language_backbone.model.text_model")
|
||||
layer_name_replace = layer_name_replace.replace("transformer.resblocks", "encoder.layers")
|
||||
layer_name_replace = layer_name_replace.replace("token_embedding", "embeddings.token_embedding")
|
||||
layer_name_replace = layer_name_replace.replace("positional_embedding", "embeddings.position_embedding.weight")
|
||||
layer_name_replace = layer_name_replace.replace(".attn", ".self_attn")
|
||||
layer_name_replace = layer_name_replace.replace(".mlp.c_fc", ".mlp.fc1")
|
||||
layer_name_replace = layer_name_replace.replace(".mlp.c_proj", ".mlp.fc2")
|
||||
layer_name_replace = layer_name_replace.replace("ln_final", "final_layer_norm")
|
||||
layer_name_replace = layer_name_replace.replace(".ln_", ".layer_norm")
|
||||
rename_keys.append((layer_name, layer_name_replace))
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v_vision(state_dict, config):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
for layer_name_vision in state_dict_keys:
|
||||
if layer_name_vision.startswith("vision_backbone") and "qkv" in layer_name_vision:
|
||||
layer_num = int(layer_name_vision.split(".")[4])
|
||||
hidden_size = config.backbone_config.embed_dim * 2**layer_num
|
||||
if "weight" in layer_name_vision:
|
||||
in_proj_weight = state_dict.pop(layer_name_vision)
|
||||
state_dict[layer_name_vision.replace("qkv.weight", "key.weight")] = in_proj_weight[:hidden_size, :]
|
||||
state_dict[layer_name_vision.replace("qkv.weight", "query.weight")] = in_proj_weight[
|
||||
hidden_size : hidden_size * 2, :
|
||||
]
|
||||
state_dict[layer_name_vision.replace("qkv.weight", "value.weight")] = in_proj_weight[-hidden_size:, :]
|
||||
elif "bias" in layer_name_vision:
|
||||
in_proj_bias = state_dict.pop(layer_name_vision)
|
||||
state_dict[layer_name_vision.replace("qkv.bias", "key.bias")] = in_proj_bias[:hidden_size]
|
||||
state_dict[layer_name_vision.replace("qkv.bias", "query.bias")] = in_proj_bias[
|
||||
hidden_size : hidden_size * 2
|
||||
]
|
||||
state_dict[layer_name_vision.replace("qkv.bias", "value.bias")] = in_proj_bias[-hidden_size:]
|
||||
|
||||
|
||||
def read_in_q_k_v_text(state_dict, config):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
hidden_size = config.text_config.projection_dim
|
||||
for layer_name_text in state_dict_keys:
|
||||
if layer_name_text.startswith("language_backbone") and "in_proj" in layer_name_text:
|
||||
if "weight" in layer_name_text:
|
||||
in_proj_weight = state_dict.pop(layer_name_text)
|
||||
state_dict[layer_name_text.replace("in_proj_weight", "q_proj.weight")] = in_proj_weight[
|
||||
:hidden_size, :
|
||||
]
|
||||
state_dict[layer_name_text.replace("in_proj_weight", "k_proj.weight")] = in_proj_weight[
|
||||
hidden_size : hidden_size * 2, :
|
||||
]
|
||||
state_dict[layer_name_text.replace("in_proj_weight", "v_proj.weight")] = in_proj_weight[
|
||||
-hidden_size:, :
|
||||
]
|
||||
elif "bias" in layer_name_text:
|
||||
in_proj_bias = state_dict.pop(layer_name_text)
|
||||
state_dict[layer_name_text.replace("in_proj_bias", "q_proj.bias")] = in_proj_bias[:hidden_size]
|
||||
state_dict[layer_name_text.replace("in_proj_bias", "k_proj.bias")] = in_proj_bias[
|
||||
hidden_size : hidden_size * 2
|
||||
]
|
||||
state_dict[layer_name_text.replace("in_proj_bias", "v_proj.bias")] = in_proj_bias[-hidden_size:]
|
||||
|
||||
|
||||
def read_in_q_k_v_encoder(state_dict, config):
|
||||
embed_dim = config.encoder_hidden_dim
|
||||
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop("encoder.encoder.0.layers.0.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict["encoder.encoder.0.layers.0.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
|
||||
state_dict["encoder.encoder.0.layers.0.self_attn.query.bias"] = in_proj_bias[:embed_dim]
|
||||
state_dict["encoder.encoder.0.layers.0.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
|
||||
state_dict["encoder.encoder.0.layers.0.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
|
||||
state_dict["encoder.encoder.0.layers.0.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
|
||||
state_dict["encoder.encoder.0.layers.0.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
|
||||
|
||||
|
||||
def read_in_q_k_v_decoder(state_dict, config):
|
||||
for layer_num in range(config.decoder_num_layers):
|
||||
embed_dim = config.decoder_hidden_dim
|
||||
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"decoder.layers.{layer_num}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"decoder.layers.{layer_num}.self_attn.query.weight"] = in_proj_weight[:embed_dim, :]
|
||||
state_dict[f"decoder.layers.{layer_num}.self_attn.query.bias"] = in_proj_bias[:embed_dim]
|
||||
state_dict[f"decoder.layers.{layer_num}.self_attn.key.weight"] = in_proj_weight[embed_dim : embed_dim * 2, :]
|
||||
state_dict[f"decoder.layers.{layer_num}.self_attn.key.bias"] = in_proj_bias[embed_dim : embed_dim * 2]
|
||||
state_dict[f"decoder.layers.{layer_num}.self_attn.value.weight"] = in_proj_weight[-embed_dim:, :]
|
||||
state_dict[f"decoder.layers.{layer_num}.self_attn.value.bias"] = in_proj_bias[-embed_dim:]
|
||||
|
||||
|
||||
def run_test(model, processor):
|
||||
# We will verify our results on an image of cute cats
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
classes = ["cat", "remote"]
|
||||
task = "Detect {}.".format(", ".join(classes))
|
||||
inputs = processor(image, text=classes, task=task, return_tensors="pt")
|
||||
|
||||
# Running forward
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
predicted_slice = outputs[1][0, :3, :3]
|
||||
print(predicted_slice)
|
||||
expected_slice = torch.tensor([[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]])
|
||||
|
||||
assert torch.allclose(predicted_slice, expected_slice, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_omdet_turbo_checkpoint(args):
|
||||
model_name = args.model_name
|
||||
pytorch_dump_folder_path = args.pytorch_dump_folder_path
|
||||
push_to_hub = args.push_to_hub
|
||||
use_timm_backbone = args.use_timm_backbone
|
||||
|
||||
checkpoint_mapping = {
|
||||
"omdet-turbo-tiny": [
|
||||
"https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/OmDet-Turbo_tiny_SWIN_T.pth",
|
||||
"https://huggingface.co/omlab/OmDet-Turbo_tiny_SWIN_T/resolve/main/ViT-B-16.pt",
|
||||
],
|
||||
}
|
||||
# Define default OmDetTurbo configuation
|
||||
config = get_omdet_turbo_config(model_name, use_timm_backbone)
|
||||
|
||||
# Load original checkpoint
|
||||
checkpoint_url = checkpoint_mapping[model_name]
|
||||
original_state_dict_vision = torch.hub.load_state_dict_from_url(checkpoint_url[0], map_location="cpu")["model"]
|
||||
original_state_dict_vision = {k.replace("module.", ""): v for k, v in original_state_dict_vision.items()}
|
||||
|
||||
# Rename keys
|
||||
new_state_dict = original_state_dict_vision.copy()
|
||||
rename_keys_vision = create_rename_keys_vision(new_state_dict, config)
|
||||
|
||||
rename_keys_language = create_rename_keys_language(new_state_dict)
|
||||
|
||||
for src, dest in rename_keys_vision:
|
||||
rename_key(new_state_dict, src, dest)
|
||||
|
||||
for src, dest in rename_keys_language:
|
||||
rename_key(new_state_dict, src, dest)
|
||||
|
||||
if not use_timm_backbone:
|
||||
read_in_q_k_v_vision(new_state_dict, config)
|
||||
read_in_q_k_v_text(new_state_dict, config)
|
||||
read_in_q_k_v_encoder(new_state_dict, config)
|
||||
read_in_q_k_v_decoder(new_state_dict, config)
|
||||
# add "model" prefix to all keys
|
||||
new_state_dict = {f"model.{k}": v for k, v in new_state_dict.items()}
|
||||
|
||||
# Load HF model
|
||||
model = OmDetTurboForObjectDetection(config)
|
||||
model.eval()
|
||||
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
|
||||
print("Missing keys:", missing_keys)
|
||||
print("Unexpected keys:", unexpected_keys)
|
||||
|
||||
image_processor = DetrImageProcessor(
|
||||
size={"height": config.backbone_image_size, "width": config.backbone_image_size},
|
||||
do_rescale=False,
|
||||
image_mean=IMAGE_MEAN,
|
||||
image_std=IMAGE_STD,
|
||||
do_pad=False,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
processor = OmDetTurboProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
|
||||
# end-to-end consistency test
|
||||
run_test(model, processor)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model.push_to_hub(f"omlab/{model_name}")
|
||||
processor.push_to_hub(f"omlab/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="omdet-turbo-tiny",
|
||||
type=str,
|
||||
choices=["omdet-turbo-tiny"],
|
||||
help="Name of the OmDetTurbo model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_timm_backbone", action="store_true", help="Whether or not to use timm backbone for vision backbone."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_omdet_turbo_checkpoint(args)
|
1810
src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
Normal file
1810
src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
Normal file
File diff suppressed because it is too large
Load Diff
362
src/transformers/models/omdet_turbo/processing_omdet_turbo.py
Normal file
362
src/transformers/models/omdet_turbo/processing_omdet_turbo.py
Normal file
@ -0,0 +1,362 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for OmDet-Turbo.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_transforms import center_to_corners_format
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Unpack
|
||||
else:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
|
||||
class OmDetTurboTextKwargs(TextKwargs, total=False):
|
||||
task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]
|
||||
|
||||
|
||||
class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False):
|
||||
text_kwargs: OmDetTurboTextKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"add_special_tokens": True,
|
||||
"padding": "max_length",
|
||||
"truncation": True,
|
||||
"max_length": 77,
|
||||
"stride": 0,
|
||||
"return_overflowing_tokens": False,
|
||||
"return_special_tokens_mask": False,
|
||||
"return_offsets_mapping": False,
|
||||
"return_token_type_ids": False,
|
||||
"return_length": False,
|
||||
"verbose": True,
|
||||
"task": None,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
}
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision.ops.boxes import batched_nms
|
||||
|
||||
|
||||
def clip_boxes(box, box_size: Tuple[int, int]):
|
||||
"""
|
||||
Clip the boxes by limiting x coordinates to the range [0, width]
|
||||
and y coordinates to the range [0, height].
|
||||
|
||||
Args:
|
||||
box (Tensor): The box to be clipped.
|
||||
box_size (height, width): The clipping box's size.
|
||||
"""
|
||||
assert torch.isfinite(box).all(), "Box tensor contains infinite or NaN!"
|
||||
height, width = box_size
|
||||
x1 = box[:, 0].clamp(min=0, max=width)
|
||||
y1 = box[:, 1].clamp(min=0, max=height)
|
||||
x2 = box[:, 2].clamp(min=0, max=width)
|
||||
y2 = box[:, 3].clamp(min=0, max=height)
|
||||
box = torch.stack((x1, y1, x2, y2), dim=-1)
|
||||
|
||||
return box
|
||||
|
||||
|
||||
def compute_score(boxes):
|
||||
"""
|
||||
Compute logit scores per class for each box (proposal) and an array of class indices
|
||||
corresponding to each proposal, flattened across the proposal_num.
|
||||
The indices in `classes` will later be used to filter and match the predicted classes
|
||||
with the input class names.
|
||||
"""
|
||||
num_classes = boxes.shape[2]
|
||||
proposal_num = boxes.shape[1]
|
||||
scores = torch.sigmoid(boxes)
|
||||
classes = torch.arange(num_classes, device=boxes.device).unsqueeze(0).repeat(proposal_num, 1).flatten(0, 1)
|
||||
return scores, classes
|
||||
|
||||
|
||||
def _post_process_boxes_for_image(
|
||||
boxes: TensorType,
|
||||
scores: TensorType,
|
||||
predicted_classes: TensorType,
|
||||
classes: List[str],
|
||||
image_size: Tuple[int, int],
|
||||
num_classes: int,
|
||||
score_threshold: float,
|
||||
nms_threshold: float,
|
||||
max_num_det: int = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Filter predicted results using given thresholds and NMS.
|
||||
Args:
|
||||
boxes (torch.Tensor): A Tensor of predicted class-specific or class-agnostic
|
||||
boxes for the image. Shape : (num_queries, max_num_classes_in_batch * 4) if doing
|
||||
class-specific regression, or (num_queries, 4) if doing class-agnostic
|
||||
regression.
|
||||
scores (torch.Tensor): A Tensor of predicted class scores for the image.
|
||||
Shape : (num_queries, max_num_classes_in_batch + 1)
|
||||
predicted_classes (torch.Tensor): A Tensor of predicted classes for the image.
|
||||
Shape : (num_queries * (max_num_classes_in_batch + 1),)
|
||||
classes (List[str]): The input classes names.
|
||||
image_size (tuple): A tuple of (height, width) for the image.
|
||||
num_classes (int): The number of classes given for this image.
|
||||
score_threshold (float): Only return detections with a confidence score exceeding this
|
||||
threshold.
|
||||
nms_threshold (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
|
||||
max_num_det (int, optional): The maximum number of detections to return. Default is None.
|
||||
Returns:
|
||||
dict: A dictionary the following keys:
|
||||
"boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format.
|
||||
"scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection.
|
||||
"classes" (List[str]): A list of strings, where each string is the predicted class for the
|
||||
corresponding detection
|
||||
"""
|
||||
proposal_num = len(boxes) if max_num_det is None else max_num_det
|
||||
scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False)
|
||||
classes_per_image = predicted_classes[topk_indices]
|
||||
box_pred_per_image = boxes.view(-1, 1, 4).repeat(1, num_classes, 1).view(-1, 4)
|
||||
box_pred_per_image = box_pred_per_image[topk_indices]
|
||||
|
||||
# Score filtering
|
||||
box_pred_per_image = center_to_corners_format(box_pred_per_image)
|
||||
box_pred_per_image = box_pred_per_image * torch.tensor(image_size[::-1]).repeat(2).to(box_pred_per_image.device)
|
||||
filter_mask = scores_per_image > score_threshold # R x K
|
||||
score_keep = filter_mask.nonzero(as_tuple=False).view(-1)
|
||||
box_pred_per_image = box_pred_per_image[score_keep]
|
||||
scores_per_image = scores_per_image[score_keep]
|
||||
classes_per_image = classes_per_image[score_keep]
|
||||
|
||||
filter_classes_mask = classes_per_image < len(classes)
|
||||
classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1)
|
||||
box_pred_per_image = box_pred_per_image[classes_keep]
|
||||
scores_per_image = scores_per_image[classes_keep]
|
||||
classes_per_image = classes_per_image[classes_keep]
|
||||
|
||||
# NMS
|
||||
keep = batched_nms(box_pred_per_image, scores_per_image, classes_per_image, nms_threshold)
|
||||
box_pred_per_image = box_pred_per_image[keep]
|
||||
scores_per_image = scores_per_image[keep]
|
||||
classes_per_image = classes_per_image[keep]
|
||||
classes_per_image = [classes[i] for i in classes_per_image]
|
||||
|
||||
# create an instance
|
||||
result = {}
|
||||
result["boxes"] = clip_boxes(box_pred_per_image, image_size)
|
||||
result["scores"] = scores_per_image
|
||||
result["classes"] = classes_per_image
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OmDetTurboProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a OmDet-Turbo processor which wraps a Deformable DETR image processor and an AutoTokenizer into a
|
||||
single processor.
|
||||
|
||||
[`OmDetTurboProcessor`] offers all the functionalities of [`DetrImageProcessor`] and
|
||||
[`AutoTokenizer`]. See the docstring of [`~OmDetTurboProcessor.__call__`] and [`~OmDetTurboProcessor.decode`]
|
||||
for more information.
|
||||
|
||||
Args:
|
||||
image_processor (`DetrImageProcessor`):
|
||||
An instance of [`DetrImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "DetrImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor, tokenizer):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[List[str], List[List[str]]] = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[OmDetTurboProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
This method uses [*DetrImageProcessor.__call__] method to prepare image(s) for the model, and
|
||||
[CLIPTokenizerFast.__call__] to prepare text for the model.
|
||||
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
|
||||
text (`Union[str, List[str], List[List[str]]]`):
|
||||
The classes used to limit the scope of the open vocabulary detection. Expects a list of strings or a list
|
||||
of list of strings. Batched classes can be of different lengths.
|
||||
Examples: ["cat", "dog", "bird"], [["cat", "dog", "bird"], ["hat", "person"], ["car"]]
|
||||
Kwargs:
|
||||
task (`Union[str, List[str], TextInput, PreTokenizedInput]`):
|
||||
The grounded text used to guide open vocabulary detection. Expects a single string or a list of strings.
|
||||
Examples: "Detect a cat, a dog, and a bird.",[ "Detect everything.", "Detect trees and flowers."]
|
||||
When not provided, the default task is "Detect [class1], [class2], [class3]" etc.
|
||||
...
|
||||
"""
|
||||
if images is None or text is None:
|
||||
raise ValueError("You have to specify both `images` and `text`")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
OmDetTurboProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = text.strip(" ").split(",")
|
||||
|
||||
if not (len(text) and isinstance(text[0], (list, tuple))):
|
||||
text = [text]
|
||||
|
||||
task = output_kwargs["text_kwargs"].pop("task", None)
|
||||
if task is None:
|
||||
task = ["Detect {}.".format(", ".join(text_single)) for text_single in text]
|
||||
elif not isinstance(task, (list, tuple)):
|
||||
task = [task]
|
||||
|
||||
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
tasks_encoding = self.tokenizer(text=task, **output_kwargs["text_kwargs"])
|
||||
|
||||
classes = text
|
||||
|
||||
classes_structure = torch.tensor([len(class_single) for class_single in classes], dtype=torch.long)
|
||||
classes_flattened = [class_single for class_batch in classes for class_single in class_batch]
|
||||
classes_encoding = self.tokenizer(text=classes_flattened, **output_kwargs["text_kwargs"])
|
||||
|
||||
encoding = BatchFeature()
|
||||
encoding.update({f"tasks_{key}": value for key, value in tasks_encoding.items()})
|
||||
encoding.update({f"classes_{key}": value for key, value in classes_encoding.items()})
|
||||
encoding.update({"classes_structure": classes_structure})
|
||||
encoding.update(encoding_image_processor)
|
||||
|
||||
return encoding
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_grounded_object_detection(
|
||||
self,
|
||||
outputs,
|
||||
classes: Union[List[str], List[List[str]]],
|
||||
score_threshold: float = 0.3,
|
||||
nms_threshold: float = 0.5,
|
||||
target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
|
||||
max_num_det: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Converts the raw output of [`OmDetTurboForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
|
||||
bottom_right_x, bottom_right_y) format and get the associated text class.
|
||||
|
||||
Args:
|
||||
outputs ([`OmDetTurboObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
classes (Union[List[str], List[List[str]]]): The input classes names.
|
||||
score_threshold (float, defaults to 0.3): Only return detections with a confidence score exceeding this
|
||||
threshold.
|
||||
nms_threshold (float, defaults to 0.5): The threshold to use for box non-maximum suppression. Value in [0, 1].
|
||||
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to None):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
`(height, width)` of each image in the batch. If unset, predictions will not be resized.
|
||||
max_num_det (int, *optional*, defaults to None): The maximum number of detections to return.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
if isinstance(classes[0], str):
|
||||
classes = [classes]
|
||||
|
||||
boxes_logits = outputs.decoder_coord_logits
|
||||
scores_logits = outputs.decoder_class_logits
|
||||
|
||||
# Inputs consistency check
|
||||
if target_sizes is None:
|
||||
height = (
|
||||
self.image_processor.size["height"]
|
||||
if "height" in self.image_processor.size
|
||||
else self.image_processor.size["shortest_edge"]
|
||||
)
|
||||
width = (
|
||||
self.image_processor.size["width"]
|
||||
if "width" in self.image_processor.size
|
||||
else self.image_processor.size["longest_edge"]
|
||||
)
|
||||
target_sizes = ((height, width),) * len(boxes_logits)
|
||||
elif len(target_sizes[0]) != 2:
|
||||
raise ValueError(
|
||||
"Each element of target_sizes must contain the size (height, width) of each image of the batch"
|
||||
)
|
||||
if len(target_sizes) != len(boxes_logits):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as output sequences")
|
||||
if len(classes) != len(boxes_logits):
|
||||
raise ValueError("Make sure that you pass in as many classes group as output sequences")
|
||||
|
||||
# Convert target_sizes to list for easier handling
|
||||
if isinstance(target_sizes, torch.Tensor):
|
||||
target_sizes = target_sizes.tolist()
|
||||
|
||||
scores, predicted_classes = compute_score(scores_logits)
|
||||
num_classes = scores_logits.shape[2]
|
||||
results = []
|
||||
for scores_img, box_per_img, image_size, class_names in zip(scores, boxes_logits, target_sizes, classes):
|
||||
results.append(
|
||||
_post_process_boxes_for_image(
|
||||
box_per_img,
|
||||
scores_img,
|
||||
predicted_classes,
|
||||
class_names,
|
||||
image_size,
|
||||
num_classes,
|
||||
score_threshold=score_threshold,
|
||||
nms_threshold=nms_threshold,
|
||||
max_num_det=max_num_det,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
@ -6587,6 +6587,20 @@ class OlmoePreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OmDetTurboForObjectDetection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OmDetTurboPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OneFormerForUniversalSegmentation(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
0
tests/models/omdet_turbo/__init__.py
Normal file
0
tests/models/omdet_turbo/__init__.py
Normal file
904
tests/models/omdet_turbo/test_modeling_omdet_turbo.py
Normal file
904
tests/models/omdet_turbo/test_modeling_omdet_turbo.py
Normal file
@ -0,0 +1,904 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch OmDet-Turbo model."""
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
|
||||
from transformers import OmDetTurboConfig, is_torch_available, is_vision_available
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import (
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import OmDetTurboForObjectDetection
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
class OmDetTurboModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=6,
|
||||
is_training=False,
|
||||
num_channels=3,
|
||||
max_text_len=7,
|
||||
num_classes=3,
|
||||
use_timm_backbone=False,
|
||||
backbone=None,
|
||||
apply_layernorm_after_vision_backbone=False,
|
||||
image_size=224,
|
||||
text_projection_in_dim=16,
|
||||
text_projection_out_dim=16,
|
||||
class_embed_dim=16,
|
||||
hidden_size=8,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_queries=20,
|
||||
encoder_in_channels=(16, 32, 64),
|
||||
encoder_dim_feedforward=32,
|
||||
num_projection_layers=1,
|
||||
decoder_n_points=4,
|
||||
num_feature_levels=3,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.is_training = is_training
|
||||
self.num_channels = num_channels
|
||||
self.max_text_len = max_text_len
|
||||
self.num_classes = num_classes
|
||||
self.use_timm_backbone = use_timm_backbone
|
||||
self.backbone = backbone
|
||||
self.apply_layernorm_after_vision_backbone = apply_layernorm_after_vision_backbone
|
||||
self.image_size = image_size
|
||||
self.text_projection_in_dim = text_projection_in_dim
|
||||
self.text_projection_out_dim = text_projection_out_dim
|
||||
self.class_embed_dim = class_embed_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_queries = num_queries
|
||||
self.encoder_in_channels = encoder_in_channels
|
||||
self.encoder_dim_feedforward = encoder_dim_feedforward
|
||||
self.num_projection_layers = num_projection_layers
|
||||
self.decoder_n_points = decoder_n_points
|
||||
self.num_feature_levels = num_feature_levels
|
||||
|
||||
self.encoder_seq_length_vision = self.image_size // 32
|
||||
self.decoder_seq_length = self.num_queries
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
input_ids_tasks = ids_tensor([self.batch_size, self.max_text_len], self.num_classes)
|
||||
input_ids_tasks = input_ids_tasks.to(torch_device)
|
||||
input_ids_classes = torch.cat(
|
||||
[ids_tensor([self.num_classes, self.max_text_len], self.num_classes) for _ in range(self.batch_size)]
|
||||
)
|
||||
input_ids_classes = input_ids_classes.to(torch_device)
|
||||
attention_mask_tasks = torch.ones_like(input_ids_tasks, device=torch_device)
|
||||
attention_mask_classes = torch.ones_like(input_ids_classes, device=torch_device)
|
||||
classes_structure = torch.ones(self.batch_size, dtype=torch.long, device=torch_device) * self.num_classes
|
||||
encoding = BatchFeature()
|
||||
encoding.update(
|
||||
{
|
||||
"pixel_values": pixel_values,
|
||||
"classes_input_ids": input_ids_classes,
|
||||
"classes_attention_mask": attention_mask_classes,
|
||||
"tasks_input_ids": input_ids_tasks,
|
||||
"tasks_attention_mask": attention_mask_tasks,
|
||||
"classes_structure": classes_structure,
|
||||
}
|
||||
)
|
||||
config = self.get_config()
|
||||
return config, encoding
|
||||
|
||||
def get_config(self):
|
||||
text_backbone = {
|
||||
"hidden_size": 16,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"intermediate_size": 16,
|
||||
"max_position_embeddings": 8,
|
||||
"model_type": "clip_text_model",
|
||||
}
|
||||
backbone_config = {
|
||||
"embed_dim": self.hidden_size,
|
||||
"depths": (1, 1, 1, 1),
|
||||
"num_heads": (1, 1, 1, 1),
|
||||
"window_size": 7,
|
||||
"image_size": self.image_size,
|
||||
"out_indices": (2, 3, 4),
|
||||
"model_type": "swin",
|
||||
}
|
||||
|
||||
return OmDetTurboConfig(
|
||||
text_config=text_backbone,
|
||||
backbone_config=backbone_config,
|
||||
use_timm_backbone=self.use_timm_backbone,
|
||||
backbone=self.backbone,
|
||||
apply_layernorm_after_vision_backbone=self.apply_layernorm_after_vision_backbone,
|
||||
decoder_num_layers=self.num_hidden_layers,
|
||||
image_size=self.image_size,
|
||||
encoder_in_channels=self.encoder_in_channels,
|
||||
num_queries=self.num_queries,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
encoder_projection_indices=[2] * self.num_projection_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_num_heads=self.num_attention_heads,
|
||||
decoder_num_points=self.decoder_n_points,
|
||||
num_feature_levels=self.num_feature_levels,
|
||||
encoder_dim_feedforward=self.encoder_dim_feedforward,
|
||||
task_encoder_hidden_dim=self.encoder_dim_feedforward,
|
||||
decoder_dim_feedforward=self.encoder_dim_feedforward,
|
||||
class_embed_dim=self.class_embed_dim,
|
||||
text_projection_in_dim=self.text_projection_in_dim,
|
||||
text_projection_out_dim=self.text_projection_out_dim,
|
||||
encoder_hidden_dim=self.hidden_size,
|
||||
decoder_hidden_dim=self.hidden_size,
|
||||
vision_features_channels=[self.hidden_size, self.hidden_size, self.hidden_size],
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_object_detection_head_model(self, config, inputs_dict):
|
||||
model = OmDetTurboForObjectDetection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(**inputs_dict)
|
||||
|
||||
self.parent.assertEqual(result.decoder_coord_logits.shape, (self.batch_size, self.num_queries, 4))
|
||||
self.parent.assertEqual(
|
||||
result.decoder_class_logits.shape, (self.batch_size, self.num_queries, self.num_classes)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class OmDetTurboModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (OmDetTurboForObjectDetection,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
pipeline_model_mapping = (
|
||||
{"zero-shot-object-detection": OmDetTurboForObjectDetection} if is_torch_available() else {}
|
||||
)
|
||||
|
||||
# special case for head models
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = OmDetTurboModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self,
|
||||
config_class=OmDetTurboConfig,
|
||||
has_text_modality=False,
|
||||
common_properties=["d_model", "encoder_attention_heads", "decoder_num_heads"],
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_object_detection_head_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_object_detection_head_model(config, inputs_dict)
|
||||
|
||||
@unittest.skip(
|
||||
reason="Unsupported as classes_input_ids are classes input are flattened by the processor: https://github.com/huggingface/transformers/issues/33669"
|
||||
)
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OmDet-Turbo does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
|
||||
def test_torchscript_output_hidden_states(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
|
||||
def test_torchscript_simple(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OmDet-Turbo does not have 'input_ids' and 'attention_mask'")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
# rewrite as OmDet-Turbo does not have "input_ids" and "decoder_input_ids"
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model_embed_pre_resize = model.get_input_embeddings()
|
||||
type_model_embed_pre_resize = type(model_embed_pre_resize)
|
||||
|
||||
if self.model_tester.is_training is False:
|
||||
model.eval()
|
||||
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||
cloned_embeddings = model_embed.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
# Check to make sure the type of embeddings returned post resizing is same as type of input
|
||||
type_model_embed_post_resize = type(model_embed)
|
||||
self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
# Input ids should be clamped to the maximum size of the vocabulary
|
||||
inputs_dict["tasks_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
|
||||
# make sure that classes_input_ids are resized as well
|
||||
if "classes_input_ids" in inputs_dict:
|
||||
inputs_dict["classes_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||
models_equal = True
|
||||
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
|
||||
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
|
||||
self.assertTrue(new_model_vocab_size, model.vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
|
||||
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
|
||||
target_dimension = 128
|
||||
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
|
||||
self.assertTrue(model_embed.weight.shape[0], target_dimension)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
|
||||
):
|
||||
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
|
||||
|
||||
# Overwrite as `init_reference_points` is not batch dependent and contains `inf` values
|
||||
def test_batching_equivalence(self):
|
||||
"""
|
||||
Tests that the model supports batching and that the output is nearly the same for the same input in
|
||||
different batch sizes.
|
||||
(Why "nearly the same" not "exactly the same"? Batching uses different matmul shapes, which often leads to
|
||||
different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
|
||||
"""
|
||||
|
||||
def get_tensor_equivalence_function(batched_input):
|
||||
# models operating on continuous spaces have higher abs difference than LMs
|
||||
# instead, we can rely on cos distance for image/speech models, similar to `diffusers`
|
||||
if "input_ids" not in batched_input:
|
||||
return lambda tensor1, tensor2: (
|
||||
1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)
|
||||
)
|
||||
return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2))
|
||||
|
||||
def recursive_check(batched_object, single_row_object, model_name, key):
|
||||
if isinstance(batched_object, (list, tuple)):
|
||||
for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
|
||||
recursive_check(batched_object_value, single_row_object_value, model_name, key)
|
||||
elif isinstance(batched_object, dict):
|
||||
for batched_object_value, single_row_object_value in zip(
|
||||
batched_object.values(), single_row_object.values()
|
||||
):
|
||||
recursive_check(batched_object_value, single_row_object_value, model_name, key)
|
||||
# do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
|
||||
elif batched_object is None or not isinstance(batched_object, torch.Tensor):
|
||||
return
|
||||
elif batched_object.dim() == 0:
|
||||
return
|
||||
elif key != "init_reference_points":
|
||||
# init
|
||||
# indexing the first element does not always work
|
||||
# e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
|
||||
slice_ids = [slice(0, index) for index in single_row_object.shape]
|
||||
batched_row = batched_object[slice_ids]
|
||||
self.assertFalse(
|
||||
torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
|
||||
)
|
||||
self.assertFalse(
|
||||
torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
|
||||
)
|
||||
self.assertFalse(
|
||||
torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
|
||||
)
|
||||
self.assertFalse(
|
||||
torch.isinf(single_row_object).any(),
|
||||
f"Single row output has `inf` in {model_name} for key={key}",
|
||||
)
|
||||
self.assertTrue(
|
||||
(equivalence(batched_row, single_row_object)) <= 1e-03,
|
||||
msg=(
|
||||
f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
|
||||
f"Difference={equivalence(batched_row, single_row_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
equivalence = get_tensor_equivalence_function(batched_input)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_hidden_states = True
|
||||
|
||||
model_name = model_class.__name__
|
||||
if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
|
||||
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
|
||||
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
batch_size = self.model_tester.batch_size
|
||||
single_row_input = {}
|
||||
for key, value in batched_input_prepared.items():
|
||||
single_batch_shape = value.shape[0] // batch_size
|
||||
single_row_input[key] = value[:single_batch_shape]
|
||||
|
||||
with torch.no_grad():
|
||||
model_batched_output = model(**batched_input_prepared)
|
||||
model_row_output = model(**single_row_input)
|
||||
|
||||
if isinstance(model_batched_output, torch.Tensor):
|
||||
model_batched_output = {"model_output": model_batched_output}
|
||||
model_row_output = {"model_output": model_row_output}
|
||||
|
||||
for key in model_batched_output:
|
||||
# DETR starts from zero-init queries to decoder, leading to cos_similarity = `nan`
|
||||
if hasattr(self, "zero_init_hidden_state") and "decoder_hidden_states" in key:
|
||||
model_batched_output[key] = model_batched_output[key][1:]
|
||||
model_row_output[key] = model_row_output[key][1:]
|
||||
if key in ("decoder_class_logits", "decoder_classes", "encoder_class_logits"):
|
||||
# check if all elements are close to 0, if so skip the test as the test strugles with comparing
|
||||
# tensors with all elements close to 0
|
||||
if torch.allclose(
|
||||
model_batched_output[key], torch.zeros_like(model_batched_output[key]), atol=1e-6
|
||||
) and torch.allclose(model_row_output[key], torch.zeros_like(model_row_output[key]), atol=1e-6):
|
||||
continue
|
||||
|
||||
recursive_check(model_batched_output[key], model_row_output[key], model_name, key)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions[-1]
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions[-1]
|
||||
self.assertEqual(
|
||||
len(attentions), self.model_tester.num_hidden_layers * self.model_tester.num_projection_layers
|
||||
)
|
||||
# Rest of the shape seems to depend on backbone output shapes and image size
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
self.model_tester.encoder_seq_length_vision**2,
|
||||
self.model_tester.encoder_seq_length_vision**2,
|
||||
],
|
||||
)
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions[0]
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
self.model_tester.num_queries + self.model_tester.max_text_len,
|
||||
self.model_tester.num_queries + self.model_tester.max_text_len,
|
||||
],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.decoder_attentions[-1]
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
self.model_tester.num_feature_levels,
|
||||
self.model_tester.decoder_n_points,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
self_attentions = outputs.encoder_attentions[-1]
|
||||
|
||||
self.assertEqual(
|
||||
len(self_attentions), self.model_tester.num_hidden_layers * self.model_tester.num_projection_layers
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
self.model_tester.encoder_seq_length_vision**2,
|
||||
self.model_tester.encoder_seq_length_vision**2,
|
||||
],
|
||||
)
|
||||
|
||||
# overwrite since encoder_hidden_states are 3-dim and not 2-dim
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_projection_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
seq_len = self.model_tester.encoder_seq_length_vision
|
||||
|
||||
self.assertListEqual(list(hidden_states[0].shape[-3:]), [self.model_tester.hidden_size, seq_len, seq_len])
|
||||
|
||||
hidden_states = outputs.decoder_hidden_states
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertIsInstance(hidden_states, (list, tuple))
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.decoder_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# removed retain_grad and grad on decoder_hidden_states, as queries don't require grad
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
output = outputs[0]
|
||||
|
||||
encoder_hidden_states = outputs.encoder_hidden_states[0]
|
||||
encoder_attentions = outputs.encoder_attentions[0][0]
|
||||
encoder_hidden_states.retain_grad()
|
||||
encoder_attentions.retain_grad()
|
||||
|
||||
cross_attentions = outputs.decoder_attentions[-1][0]
|
||||
cross_attentions.retain_grad()
|
||||
|
||||
output.flatten()[0].backward(retain_graph=True)
|
||||
|
||||
self.assertIsNotNone(encoder_hidden_states.grad)
|
||||
self.assertIsNotNone(encoder_attentions.grad)
|
||||
self.assertIsNotNone(cross_attentions.grad)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if (
|
||||
"embeddings" in name
|
||||
or ".fc" in name
|
||||
or "decoder.channel_projection_layers" in name
|
||||
or "query_position_head" in name
|
||||
or "decoder.encoder_vision_features" in name
|
||||
):
|
||||
continue
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} seems not properly initialized",
|
||||
)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def prepare_text():
|
||||
classes = ["cat", "remote"]
|
||||
task = "Detect {}.".format(", ".join(classes))
|
||||
return classes, task
|
||||
|
||||
|
||||
def prepare_img_batched():
|
||||
url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
|
||||
url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
|
||||
|
||||
return [Image.open(BytesIO(requests.get(url).content)).convert("RGB") for url in [url1, url2, url3]]
|
||||
|
||||
|
||||
def prepare_text_batched():
|
||||
classes1 = ["cat", "remote"]
|
||||
classes2 = ["boat"]
|
||||
classes3 = ["statue", "trees", "torch"]
|
||||
|
||||
task1 = "Detect {}.".format(", ".join(classes1))
|
||||
task2 = "Detect all the boat in the image."
|
||||
task3 = "Focus on the foreground, detect statue, torch and trees."
|
||||
return [classes1, classes2, classes3], [task1, task2, task3]
|
||||
|
||||
|
||||
@require_timm
|
||||
@require_vision
|
||||
@slow
|
||||
class OmDetTurboModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") if is_vision_available() else None
|
||||
|
||||
def test_inference_object_detection_head(self):
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, task = prepare_text()
|
||||
encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
|
||||
expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4))
|
||||
expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2))
|
||||
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
|
||||
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
|
||||
|
||||
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to(
|
||||
torch_device
|
||||
)
|
||||
expected_coord_logits = torch.tensor(
|
||||
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
|
||||
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
|
||||
|
||||
self.assertEqual(len(results["scores"]), 4)
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
|
||||
|
||||
expected_classes = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["classes"], expected_classes)
|
||||
|
||||
def test_inference_object_detection_head_fp16(self):
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(
|
||||
torch_device, dtype=torch.float16
|
||||
)
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, task = prepare_text()
|
||||
encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(
|
||||
torch_device, dtype=torch.float16
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
|
||||
expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4))
|
||||
expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2))
|
||||
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
|
||||
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
|
||||
|
||||
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to(
|
||||
torch_device, dtype=torch.float16
|
||||
)
|
||||
expected_coord_logits = torch.tensor(
|
||||
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
|
||||
).to(torch_device, dtype=torch.float16)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
|
||||
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device, dtype=torch.float16)
|
||||
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(
|
||||
torch_device, dtype=torch.float16
|
||||
)
|
||||
|
||||
self.assertEqual(len(results["scores"]), 4)
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-1))
|
||||
|
||||
expected_classes = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["classes"], expected_classes)
|
||||
|
||||
def test_inference_object_detection_head_no_task(self):
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, _ = prepare_text()
|
||||
encoding = processor(images=image, text=classes, return_tensors="pt").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
|
||||
expected_shape_coord_logits = torch.Size((1, model.config.num_queries, 4))
|
||||
expected_shape_class_logits = torch.Size((1, model.config.num_queries, 2))
|
||||
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
|
||||
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
|
||||
|
||||
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]]).to(
|
||||
torch_device
|
||||
)
|
||||
expected_coord_logits = torch.tensor(
|
||||
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
|
||||
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
|
||||
|
||||
self.assertEqual(len(results["scores"]), 4)
|
||||
self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))
|
||||
|
||||
expected_classes = ["remote", "cat", "remote", "cat"]
|
||||
self.assertListEqual(results["classes"], expected_classes)
|
||||
|
||||
def test_inference_object_detection_head_batched(self):
|
||||
torch_device = "cpu"
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)
|
||||
|
||||
processor = self.default_processor
|
||||
images_batched = prepare_img_batched()
|
||||
classes_batched, tasks_batched = prepare_text_batched()
|
||||
encoding = processor(images=images_batched, text=classes_batched, task=tasks_batched, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoding)
|
||||
|
||||
expected_shape_coord_logits = torch.Size((len(images_batched), model.config.num_queries, 4))
|
||||
expected_shape_class_logits = torch.Size((len(images_batched), model.config.num_queries, 3))
|
||||
self.assertEqual(outputs.decoder_coord_logits.shape, expected_shape_coord_logits)
|
||||
self.assertEqual(outputs.decoder_class_logits.shape, expected_shape_class_logits)
|
||||
|
||||
expected_class_logits = torch.tensor(
|
||||
[[[0.9427, -2.5958, -7.7601]], [[-2.3408, -9.3516, -9.3516]], [[1.0740, -2.3315, -1.1885]]]
|
||||
).to(torch_device)
|
||||
|
||||
expected_coord_logits = torch.tensor(
|
||||
[[[0.2550, 0.5501, 0.4738]], [[0.2535, 0.6006, 0.0353]], [[0.3742, 0.3337, 0.0666]]]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.decoder_class_logits[:, :1, :3], expected_class_logits, atol=1e-1))
|
||||
self.assertTrue(torch.allclose(outputs.decoder_coord_logits[:, :1, :3], expected_coord_logits, atol=1e-3))
|
||||
|
||||
# verify grounded postprocessing
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
classes=classes_batched,
|
||||
target_sizes=[image.size[::-1] for image in images_batched],
|
||||
score_threshold=0.2,
|
||||
)
|
||||
expected_scores = torch.tensor([0.7675, 0.3016, 0.7454]).to(torch_device)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[
|
||||
[39.8870, 70.3522, 176.7424, 118.0354],
|
||||
[146.5446, 219.7132, 209.6983, 251.0456],
|
||||
[545.3470, 209.9055, 651.9860, 502.1882],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertListEqual([len(result["scores"]) for result in results], [4, 4, 6])
|
||||
self.assertTrue(
|
||||
torch.allclose(torch.stack([result["scores"][0] for result in results]), expected_scores, atol=1e-2)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(torch.stack([result["boxes"][0, :] for result in results]), expected_slice_boxes, atol=1e-2)
|
||||
)
|
||||
|
||||
expected_classes = [
|
||||
["remote", "cat", "remote", "cat"],
|
||||
["boat", "boat", "boat", "boat"],
|
||||
["statue", "trees", "trees", "torch", "statue", "statue"],
|
||||
]
|
||||
self.assertListEqual([result["classes"] for result in results], expected_classes)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
classes, task = prepare_text()
|
||||
encoding = processor(images=image, text=classes, task=task, return_tensors="pt")
|
||||
# 1. run model on CPU
|
||||
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
|
||||
|
||||
with torch.no_grad():
|
||||
cpu_outputs = model(**encoding)
|
||||
|
||||
# 2. run model on GPU
|
||||
model.to("cuda")
|
||||
encoding = encoding.to("cuda")
|
||||
with torch.no_grad():
|
||||
gpu_outputs = model(**encoding)
|
||||
|
||||
# 3. assert equivalence
|
||||
expected_class_logits = torch.tensor([[[0.9427, -2.5958], [0.2105, -3.4569], [-2.6364, -4.1610]]])
|
||||
expected_coord_logits = torch.tensor(
|
||||
[[[0.2550, 0.5501, 0.4738, 0.8745], [0.7695, 0.4121, 0.4603, 0.7244], [0.7691, 0.4117, 0.4603, 0.7214]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(cpu_outputs.decoder_class_logits[:3, :3], expected_class_logits, atol=1e-1))
|
||||
self.assertTrue(torch.allclose(cpu_outputs.decoder_coord_logits[:3, :3], expected_coord_logits, atol=1e-3))
|
||||
|
||||
# verify grounded postprocessing
|
||||
results_cpu = processor.post_process_grounded_object_detection(
|
||||
cpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
result_gpu = processor.post_process_grounded_object_detection(
|
||||
gpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-2))
|
||||
self.assertTrue(torch.allclose(results_cpu["boxes"][0, :], result_gpu["boxes"][0, :].cpu(), atol=1e-2))
|
363
tests/models/omdet_turbo/test_processor_omdet_turbo.py
Normal file
363
tests/models/omdet_turbo/test_processor_omdet_turbo.py
Normal file
@ -0,0 +1,363 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import AutoProcessor, CLIPTokenizerFast, OmDetTurboProcessor
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
IMAGE_MEAN = [123.675, 116.28, 103.53]
|
||||
IMAGE_STD = [58.395, 57.12, 57.375]
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.omdet_turbo.modeling_omdet_turbo import OmDetTurboObjectDetectionOutput
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import DetrImageProcessor
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = OmDetTurboProcessor
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = DetrImageProcessor()
|
||||
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
processor = OmDetTurboProcessor(image_processor, tokenizer)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
self.input_keys = [
|
||||
"tasks_input_ids",
|
||||
"tasks_attention_mask",
|
||||
"classes_input_ids",
|
||||
"classes_attention_mask",
|
||||
"classes_structure",
|
||||
"pixel_values",
|
||||
"pixel_mask",
|
||||
]
|
||||
|
||||
self.batch_size = 5
|
||||
self.num_queries = 5
|
||||
self.embed_dim = 3
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
"""
|
||||
|
||||
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
def get_fake_omdet_turbo_output(self):
|
||||
torch.manual_seed(42)
|
||||
return OmDetTurboObjectDetectionOutput(
|
||||
decoder_coord_logits=torch.rand(self.batch_size, self.num_queries, 4),
|
||||
decoder_class_logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
|
||||
)
|
||||
|
||||
def get_fake_omdet_turbo_classes(self):
|
||||
return [[f"class{i}_{j}" for i in range(self.num_queries)] for j in range(self.batch_size)]
|
||||
|
||||
def test_post_process_grounded_object_detection(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
omdet_turbo_output = self.get_fake_omdet_turbo_output()
|
||||
omdet_turbo_classes = self.get_fake_omdet_turbo_classes()
|
||||
|
||||
post_processed = processor.post_process_grounded_object_detection(
|
||||
omdet_turbo_output, omdet_turbo_classes, target_sizes=[(400, 30) for _ in range(self.batch_size)]
|
||||
)
|
||||
|
||||
self.assertEqual(len(post_processed), self.batch_size)
|
||||
self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "classes"])
|
||||
self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
|
||||
self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
|
||||
expected_scores = torch.tensor([0.7310, 0.6579, 0.6513, 0.6444, 0.6252])
|
||||
self.assertTrue(torch.allclose(post_processed[0]["scores"], expected_scores, atol=1e-4))
|
||||
|
||||
expected_box_slice = torch.tensor([14.9657, 141.2052, 30.0000, 312.9670])
|
||||
self.assertTrue(torch.allclose(post_processed[0]["boxes"][0], expected_box_slice, atol=1e-4))
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = OmDetTurboProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
|
||||
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
|
||||
|
||||
processor = OmDetTurboProcessor.from_pretrained(
|
||||
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
|
||||
)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)
|
||||
|
||||
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.image_processor, DetrImageProcessor)
|
||||
|
||||
def test_image_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor).image_processor
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
input_image_proc = image_processor(image_input, return_tensors="np")
|
||||
input_processor = processor(images=image_input, return_tensors="np")
|
||||
|
||||
for key in input_image_proc.keys():
|
||||
self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_tokenizer(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor).tokenizer
|
||||
|
||||
input_str = "lower newer"
|
||||
|
||||
encoded_processor = processor(text=input_str, padding="max_length", truncation=True, max_length=77)
|
||||
|
||||
encoded_tok = tokenizer(input_str, padding="max_length", truncation=True, max_length=77)
|
||||
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
input_tasks = "task"
|
||||
input_classes = ["class1", "class2"]
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
input_processor = processor(images=image_input, text=input_classes, task=input_tasks, return_tensors="pt")
|
||||
|
||||
for key in self.input_keys:
|
||||
assert torch.is_tensor(input_processor[key])
|
||||
# test if it raises when no input is passed
|
||||
with pytest.raises(ValueError):
|
||||
processor()
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||
|
||||
decoded_processor = processor.batch_decode(predicted_ids)
|
||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||
|
||||
self.assertListEqual(decoded_tok, decoded_processor)
|
||||
|
||||
def test_model_input_names(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = OmDetTurboProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
|
||||
input_tasks = "task"
|
||||
input_classes = ["class1", "class2"]
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(images=image_input, text=input_classes, task=input_tasks, return_tensors="pt")
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), self.input_keys)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_tokenizer_defaults_preserved_by_kwargs(self):
|
||||
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(images=image_input, text=[input_str], task=input_str, return_tensors="pt")
|
||||
|
||||
self.assertEqual(len(inputs["tasks_input_ids"][0]), 117)
|
||||
self.assertEqual(len(inputs["classes_input_ids"][0]), 117)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs(self):
|
||||
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(images=image_input, text=[input_str], task=input_str, return_tensors="pt", max_length=112)
|
||||
|
||||
self.assertEqual(len(inputs["tasks_input_ids"][0]), 112)
|
||||
self.assertEqual(len(inputs["classes_input_ids"][0]), 112)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs(self):
|
||||
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(
|
||||
images=image_input,
|
||||
text=[input_str],
|
||||
task=input_str,
|
||||
return_tensors="pt",
|
||||
size={"height": 214, "width": 214},
|
||||
padding="max_length",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
|
||||
self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs_batched(self):
|
||||
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer", "upper older longer string"]
|
||||
image_input = self.prepare_image_inputs() * 2
|
||||
inputs = processor(
|
||||
images=image_input,
|
||||
text=[input_str],
|
||||
task=input_str,
|
||||
return_tensors="pt",
|
||||
size={"height": 214, "width": 214},
|
||||
padding="longest",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["tasks_input_ids"][0]), 6)
|
||||
self.assertEqual(len(inputs["classes_input_ids"][0]), 6)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_structured_kwargs_nested(self):
|
||||
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76, "task": input_str},
|
||||
}
|
||||
|
||||
inputs = processor(images=image_input, text=[input_str], **all_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
|
||||
self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_structured_kwargs_nested_from_dict(self):
|
||||
# Rewrite as OmDet-Turbo processor outputs "input_ids" for both tasks and classes.
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76, "task": input_str},
|
||||
}
|
||||
|
||||
inputs = processor(images=image_input, text=[input_str], **all_kwargs)
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["tasks_input_ids"][0]), 76)
|
||||
self.assertEqual(len(inputs["classes_input_ids"][0]), 76)
|
@ -173,7 +173,13 @@ MODEL_NAMES_WITH_SAME_CONFIG = {
|
||||
"XLS-R": "Wav2Vec2",
|
||||
"XLSR-Wav2Vec2": "Wav2Vec2",
|
||||
}
|
||||
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel", "Qwen2AudioEncoder"]
|
||||
MODEL_NAMES_TO_IGNORE = [
|
||||
"ChineseCLIPVisionModel",
|
||||
"CLIPTextModel",
|
||||
"CLIPVisionModel",
|
||||
"Qwen2AudioEncoder",
|
||||
"SiglipVisionModel",
|
||||
]
|
||||
|
||||
|
||||
def get_model_table_from_auto_modules() -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user