Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Add YOLOS (#16848)

* First draft
* Add YolosForObjectDetection
* Make forward pass work
* Add mid position embeddings
* Add interpolation of position encodings
* Add expected values
* Add YOLOS to tests
* Add integration test
* Support tiny model as well
* Support all models in conversion script
* Remove mid_pe_size attribute
* Make more tests pass
* Add model to README and fix config
* Add copied from statements
* Rename base_model_prefix to vit
* Add missing YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP
* Apply suggestions from code review
* Apply more suggestions from code review
* Convert remaining checkpoints
* Improve docstrings
* Add YolosFeatureExtractor
* Add feature extractor to docs
* Add corresponding tests
* Fix style
* Fix docs
* Apply suggestion from code review
* Fix bad rebase
* Fix some more bad rebase
* Fix missing character
* Improve docs and variable names

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>

This commit is contained in: parent f275e593bf, commit 1ac698744c
@ -340,6 +340,7 @@ AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Ch
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[XLS-R](https://huggingface.co/docs/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[YOLOS](https://huggingface.co/docs/transformers/main/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
1. **[YOSO](https://huggingface.co/docs/transformers/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
@ -318,6 +318,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLS-R](https://huggingface.co/docs/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[YOLOS](https://huggingface.co/docs/transformers/main/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
1. **[YOSO](https://huggingface.co/docs/transformers/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have added a **detailed guide and templates** to help you add a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
@ -342,6 +342,7 @@ conda install -c huggingface transformers
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLS-R](https://huggingface.co/docs/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[YOLOS](https://huggingface.co/docs/transformers/main/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
1. **[YOSO](https://huggingface.co/docs/transformers/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have a **detailed guide and templates** to walk you through adding a new one. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open a new issue to get feedback before starting your PR.
@ -354,6 +354,7 @@ conda install -c huggingface transformers
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLS-R](https://huggingface.co/docs/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[YOLOS](https://huggingface.co/docs/transformers/main/model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
1. **[YOSO](https://huggingface.co/docs/transformers/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have a **detailed guide and templates** to walk you through adding a new model. You can find them in the [`templates`](./templates) folder. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open a new issue to get feedback before starting your PR.
@ -376,6 +376,8 @@
  title: XLSR-Wav2Vec2
- local: model_doc/xls_r
  title: XLS-R
- local: model_doc/yolos
  title: YOLOS
- local: model_doc/yoso
  title: YOSO
title: Models
@ -160,6 +160,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
1. **[XLNet](model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[XLS-R](model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[YOLOS](model_doc/yolos)** (from Huazhong University of Science & Technology) released with the paper [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
1. **[YOSO](model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
@ -274,6 +275,7 @@ Flax), PyTorch, and/or TensorFlow.
| XLM-RoBERTa-XL | ❌ | ❌ | ✅ | ❌ | ❌ |
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
| YOLOS | ❌ | ❌ | ✅ | ❌ | ❌ |
| YOSO | ❌ | ❌ | ✅ | ❌ | ❌ |

<!-- End table-->
docs/source/en/model_doc/yolos.mdx (new file, 60 lines)
@ -0,0 +1,60 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# YOLOS

## Overview

The YOLOS model was proposed in [You Only Look at One Sequence: Rethinking Transformer in Vision through Object Detection](https://arxiv.org/abs/2106.00666) by Yuxin Fang, Bencheng Liao, Xinggang Wang, Jiemin Fang, Jiyang Qi, Rui Wu, Jianwei Niu, Wenyu Liu.
YOLOS proposes to simply leverage the plain [Vision Transformer (ViT)](vit) for object detection, inspired by DETR. It turns out that a base-sized encoder-only Transformer can also achieve 42 AP on COCO, similar to DETR and much more complex frameworks such as Faster R-CNN.

The abstract from the paper is the following:

*Can Transformer perform 2D object- and region-level recognition from a pure sequence-to-sequence perspective with minimal knowledge about the 2D spatial structure? To answer this question, we present You Only Look at One Sequence (YOLOS), a series of object detection models based on the vanilla Vision Transformer with the fewest possible modifications, region priors, as well as inductive biases of the target task. We find that YOLOS pre-trained on the mid-sized ImageNet-1k dataset only can already achieve quite competitive performance on the challenging COCO object detection benchmark, e.g., YOLOS-Base directly adopted from BERT-Base architecture can obtain 42.0 box AP on COCO val. We also discuss the impacts as well as limitations of current pre-train schemes and model scaling strategies for Transformer in vision through YOLOS.*

Tips:

- One can use [`YolosFeatureExtractor`] for preparing images (and optional targets) for the model; see the inference example further below. Contrary to [DETR](detr), YOLOS doesn't require a `pixel_mask` to be created.
- Demo notebooks (regarding inference and fine-tuning on custom data) can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/YOLOS).

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/yolos_architecture.png"
alt="drawing" width="600"/>

<small> YOLOS architecture. Taken from the <a href="https://arxiv.org/abs/2106.00666">original paper</a>.</small>

This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/hustvl/YOLOS).

## YolosConfig

[[autodoc]] YolosConfig

## YolosFeatureExtractor

[[autodoc]] YolosFeatureExtractor
    - __call__
    - pad
    - post_process
    - post_process_segmentation
    - post_process_panoptic

## YolosModel

[[autodoc]] YolosModel
    - forward

## YolosForObjectDetection

[[autodoc]] YolosForObjectDetection
    - forward
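Putting the tips above together, a minimal inference sketch could look as follows. It assumes the `hustvl/yolos-small` checkpoint converted in this PR and the DETR-style `post_process` method exposed by [`YolosFeatureExtractor`]:

```python
import requests
import torch
from PIL import Image

from transformers import YolosFeatureExtractor, YolosForObjectDetection

# COCO validation image of two cats, also used in the conversion script below
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")

# Prepare the image; unlike DETR, no pixel_mask is created or required
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Rescale the predicted boxes to the original image size and keep confident detections
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
results = feature_extractor.post_process(outputs, target_sizes)[0]

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    if score > 0.9:
        print(model.config.id2label[label.item()], round(score.item(), 3), box.tolist())
```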
@ -320,6 +320,7 @@ _import_structure = {
|
||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||
"models.xlm_roberta_xl": ["XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaXLConfig"],
|
||||
"models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
|
||||
"models.yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"],
|
||||
"models.yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
|
||||
"onnx": [],
|
||||
"pipelines": [
|
||||
@ -551,6 +552,7 @@ if is_vision_available():
|
||||
_import_structure["models.vilt"].append("ViltFeatureExtractor")
|
||||
_import_structure["models.vilt"].append("ViltProcessor")
|
||||
_import_structure["models.vit"].append("ViTFeatureExtractor")
|
||||
_import_structure["models.yolos"].append("YolosFeatureExtractor")
|
||||
else:
|
||||
from .utils import dummy_vision_objects
|
||||
|
||||
@ -1681,6 +1683,14 @@ if is_torch_available():
|
||||
"load_tf_weights_in_xlnet",
|
||||
]
|
||||
)
|
||||
_import_structure["models.yolos"].extend(
|
||||
[
|
||||
"YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"YolosForObjectDetection",
|
||||
"YolosModel",
|
||||
"YolosPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.yoso"].extend(
|
||||
[
|
||||
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -2696,6 +2706,7 @@ if TYPE_CHECKING:
|
||||
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||
from .models.xlm_roberta_xl import XLM_ROBERTA_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaXLConfig
|
||||
from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
||||
from .models.yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
|
||||
from .models.yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
|
||||
|
||||
# Pipelines
|
||||
@ -2901,6 +2912,7 @@ if TYPE_CHECKING:
|
||||
from .models.segformer import SegformerFeatureExtractor
|
||||
from .models.vilt import ViltFeatureExtractor, ViltProcessor
|
||||
from .models.vit import ViTFeatureExtractor
|
||||
from .models.yolos import YolosFeatureExtractor
|
||||
else:
|
||||
from .utils.dummy_vision_objects import *
|
||||
|
||||
@ -3831,6 +3843,12 @@ if TYPE_CHECKING:
|
||||
XLNetPreTrainedModel,
|
||||
load_tf_weights_in_xlnet,
|
||||
)
|
||||
from .models.yolos import (
|
||||
YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
YolosForObjectDetection,
|
||||
YolosModel,
|
||||
YolosPreTrainedModel,
|
||||
)
|
||||
from .models.yoso import (
|
||||
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
YosoForMaskedLM,
|
||||
|
@ -135,5 +135,6 @@ from . import (
|
||||
xlm_roberta,
|
||||
xlm_roberta_xl,
|
||||
xlnet,
|
||||
yolos,
|
||||
yoso,
|
||||
)
|
||||
|
@ -29,6 +29,7 @@ logger = logging.get_logger(__name__)
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("yolos", "YolosConfig"),
|
||||
("tapex", "BartConfig"),
|
||||
("dpt", "DPTConfig"),
|
||||
("decision_transformer", "DecisionTransformerConfig"),
|
||||
@ -138,6 +139,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Add archive maps here)
|
||||
("yolos", "YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("glpn", "GLPN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@ -231,6 +233,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_NAMES_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("yolos", "YOLOS"),
|
||||
("tapex", "TAPEX"),
|
||||
("dpt", "DPT"),
|
||||
("decision_transformer", "Decision Transformer"),
|
||||
|
@ -61,6 +61,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-vision", "BeitFeatureExtractor"),
|
||||
("dpt", "DPTFeatureExtractor"),
|
||||
("glpn", "GLPNFeatureExtractor"),
|
||||
("yolos", "YolosFeatureExtractor"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
||||
MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
("yolos", "YolosModel"),
|
||||
("dpt", "DPTModel"),
|
||||
("decision_transformer", "DecisionTransformerModel"),
|
||||
("glpn", "GLPNModel"),
|
||||
@ -386,6 +387,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Object Detection mapping
|
||||
("yolos", "YolosForObjectDetection"),
|
||||
("detr", "DetrForObjectDetection"),
|
||||
]
|
||||
)
|
||||
|
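Because YOLOS is now registered in the auto mappings above (config, feature extractor, base model and object-detection model), the generic auto classes can resolve it from the `model_type` stored in a checkpoint. A minimal sketch, assuming the `hustvl/yolos-small` checkpoint and the `AutoModelForObjectDetection` class that already backs the DETR mapping:

```python
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForObjectDetection

config = AutoConfig.from_pretrained("hustvl/yolos-small")                       # -> YolosConfig
feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small")  # -> YolosFeatureExtractor
model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-small")       # -> YolosForObjectDetection

print(type(config).__name__, type(feature_extractor).__name__, type(model).__name__)
```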
@ -1865,30 +1865,31 @@ class DetrLoss(nn.Module):
|
||||
"""
|
||||
This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
|
||||
we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
|
||||
of matched ground-truth / prediction (supervise class and box)
|
||||
of matched ground-truth / prediction (supervise class and box).
|
||||
|
||||
A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
|
||||
parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
|
||||
the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
|
||||
be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
|
||||
(`max_obj_id` + 1). For more details on this, check the following discussion
|
||||
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
|
||||
|
||||
|
||||
Args:
|
||||
matcher (`DetrHungarianMatcher`):
|
||||
Module able to compute a matching between targets and proposals.
|
||||
num_classes (`int`):
|
||||
Number of object categories, omitting the special no-object category.
|
||||
eos_coef (`float`):
|
||||
Relative classification weight applied to the no-object category.
|
||||
losses (`List[str]`):
|
||||
List of all the losses to be applied. See `get_loss` for a list of all available losses.
|
||||
"""
|
||||
|
||||
def __init__(self, matcher, num_classes, eos_coef, losses):
|
||||
"""
|
||||
Create the criterion.
|
||||
|
||||
A note on the num_classes parameter (copied from original repo in detr.py): "the naming of the `num_classes`
|
||||
parameter of the criterion is somewhat misleading. it indeed corresponds to `max_obj_id + 1`, where max_obj_id
|
||||
is the maximum id for a class in your dataset. For example, COCO has a max_obj_id of 90, so we pass
|
||||
`num_classes` to be 91. As another example, for a dataset that has a single class with id 1, you should pass
|
||||
`num_classes` to be 2 (max_obj_id + 1). For more details on this, check the following discussion
|
||||
https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
|
||||
|
||||
Parameters:
|
||||
matcher: module able to compute a matching between targets and proposals.
|
||||
num_classes: number of object categories, omitting the special no-object category.
|
||||
weight_dict: dict containing as key the names of the losses and as values their relative weight.
|
||||
eos_coef: relative classification weight applied to the no-object category.
|
||||
losses: list of all the losses to be applied. See get_loss for list of available losses.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.matcher = matcher
|
||||
self.num_classes = num_classes
|
||||
self.eos_coef = eos_coef
|
||||
self.losses = losses
|
||||
empty_weight = torch.ones(self.num_classes + 1)
|
||||
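The `num_classes` convention documented above can be made concrete in a couple of lines; the truncated `id2label` dict here is purely illustrative:

```python
# COCO detection uses class ids up to 90, so the criterion is built with num_classes = 91
id2label = {1: "person", 2: "bicycle", 90: "toothbrush"}  # truncated mapping, for illustration only

max_obj_id = max(id2label.keys())
num_classes = max_obj_id + 1  # 91 for COCO, even though only 80 categories are actually annotated

# The special no-object class is appended on top of that, hence empty_weight has num_classes + 1 entries
print(num_classes)  # 91
```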
@ -2017,10 +2018,12 @@ class DetrLoss(nn.Module):
|
||||
"""
|
||||
This performs the loss computation.
|
||||
|
||||
Parameters:
|
||||
outputs: dict of tensors, see the output specification of the model for the format
|
||||
targets: list of dicts, such that len(targets) == batch_size.
|
||||
The expected keys in each dict depends on the losses applied, see each loss' doc
|
||||
Args:
|
||||
outputs (`dict`, *optional*):
|
||||
Dictionary of tensors, see the output specification of the model for the format.
|
||||
targets (`List[dict]`, *optional*):
|
||||
List of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
||||
|
||||
@ -2086,20 +2089,18 @@ class DetrHungarianMatcher(nn.Module):
|
||||
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
||||
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
||||
un-matched (and thus treated as non-objects).
|
||||
|
||||
Args:
|
||||
class_cost:
|
||||
The relative weight of the classification error in the matching cost.
|
||||
bbox_cost:
|
||||
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
||||
giou_cost:
|
||||
The relative weight of the giou loss of the bounding box in the matching cost.
|
||||
"""
|
||||
|
||||
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
||||
"""
|
||||
Creates the matcher.
|
||||
|
||||
Params:
|
||||
class_cost: This is the relative weight of the classification error in the matching cost
|
||||
bbox_cost:
|
||||
This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
|
||||
giou_cost: This is the relative weight of the giou loss of the bounding box in the matching cost
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
requires_backends(self, ["scipy"])
|
||||
|
||||
self.class_cost = class_cost
|
||||
@ -2111,25 +2112,25 @@ class DetrHungarianMatcher(nn.Module):
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, targets):
|
||||
"""
|
||||
Performs the matching.
|
||||
|
||||
Params:
|
||||
outputs: This is a dict that contains at least these entries:
|
||||
"logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
"pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
|
||||
targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
"class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
|
||||
objects in the target) containing the class labels "boxes": Tensor of dim [num_target_boxes, 4]
|
||||
containing the target box coordinates
|
||||
Args:
|
||||
outputs (`dict`):
|
||||
A dictionary that contains at least these entries:
|
||||
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
||||
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
||||
targets (`List[dict]`):
|
||||
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
||||
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
||||
ground-truth
|
||||
objects in the target) containing the class labels
|
||||
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
||||
|
||||
Returns:
|
||||
A list of size batch_size, containing tuples of (index_i, index_j) where:
|
||||
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
||||
- index_i is the indices of the selected predictions (in order)
|
||||
- index_j is the indices of the corresponding selected targets (in order)
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
||||
"""
|
||||
bs, num_queries = outputs["logits"].shape[:2]
|
||||
batch_size, num_queries = outputs["logits"].shape[:2]
|
||||
|
||||
# We flatten to compute the cost matrices in a batch
|
||||
out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
|
||||
@ -2152,7 +2153,7 @@ class DetrHungarianMatcher(nn.Module):
|
||||
|
||||
# Final cost matrix
|
||||
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
||||
cost_matrix = cost_matrix.view(bs, num_queries, -1).cpu()
|
||||
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
||||
|
||||
sizes = [len(v["boxes"]) for v in targets]
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
||||
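The matcher above reduces to a classic linear-sum assignment once the weighted cost matrix is built. A toy example of that final step, using the same `scipy` routine:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Toy cost matrix: 4 predicted queries vs. 2 ground-truth boxes of a single image.
# In DetrHungarianMatcher this is the weighted sum of class, L1 box and GIoU costs.
cost_matrix = np.array(
    [[0.9, 0.1],
     [0.2, 0.8],
     [0.5, 0.5],
     [0.7, 0.3]]
)

pred_indices, target_indices = linear_sum_assignment(cost_matrix)
# Each ground-truth box is matched to exactly one prediction, minimizing the total cost:
# pairs (prediction index, target index) are (0, 1) and (1, 0) here
print(list(zip(pred_indices, target_indices)))
```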
@ -2175,11 +2176,12 @@ def box_area(boxes: Tensor) -> Tensor:
|
||||
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
||||
|
||||
Args:
|
||||
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
|
||||
are expected to be in (x1, y1, x2, y2) format with `0 <= x1 < x2` and `0 <= y1 < y2`.
|
||||
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
||||
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
||||
< x2` and `0 <= y1 < y2`.
|
||||
|
||||
Returns:
|
||||
area (Tensor[N]): area for each box
|
||||
`torch.FloatTensor`: a tensor containing the area for each box.
|
||||
"""
|
||||
boxes = _upcast(boxes)
|
||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
@ -2190,11 +2192,11 @@ def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
||||
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
@ -2207,7 +2209,7 @@ def generalized_box_iou(boxes1, boxes2):
|
||||
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
||||
|
||||
Returns:
|
||||
a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
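The pairwise IoU that `box_iou` computes can be checked by hand for a single pair of boxes; a minimal sketch using the same corner-format conventions as the code above:

```python
import torch

# Two boxes in (x0, y0, x1, y1) corner format
boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])

# Intersection: the overlapping square [1, 2] x [1, 2] has area 1
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
width_height = (right_bottom - left_top).clamp(min=0)
inter = width_height[..., 0] * width_height[..., 1]

# Union: 4 + 4 - 1 = 7, so IoU = 1 / 7 ≈ 0.1429
area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
union = area1[:, None] + area2 - inter
print(inter / union)  # tensor([[0.1429]])
```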
@ -2242,7 +2244,6 @@ class NestedTensor(object):
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
|
src/transformers/models/yolos/__init__.py (new file, 57 lines)
@ -0,0 +1,57 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"],
|
||||
}
|
||||
|
||||
if is_vision_available():
|
||||
_import_structure["feature_extraction_yolos"] = ["YolosFeatureExtractor"]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_yolos"] = [
|
||||
"YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"YolosForObjectDetection",
|
||||
"YolosModel",
|
||||
"YolosPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
|
||||
|
||||
if is_vision_available():
|
||||
from .feature_extraction_yolos import YolosFeatureExtractor
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_yolos import (
|
||||
YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
YolosForObjectDetection,
|
||||
YolosModel,
|
||||
YolosPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
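With the `_LazyModule` wiring above (mirrored in the top-level `transformers/__init__.py` earlier in this diff), the new symbols are available from the package root and their submodules are only imported on first access. A quick sketch:

```python
from transformers import YolosConfig, YolosFeatureExtractor, YolosForObjectDetection

# Building a model from a fresh config; requires torch (and Pillow for the feature extractor)
config = YolosConfig(num_detection_tokens=100)
model = YolosForObjectDetection(config)
feature_extractor = YolosFeatureExtractor(size=512)
```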
src/transformers/models/yolos/configuration_yolos.py (new file, 153 lines)
@ -0,0 +1,153 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" YOLOS model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"hustvl/yolos-small": "https://huggingface.co/hustvl/yolos-small/resolve/main/config.json",
|
||||
# See all YOLOS models at https://huggingface.co/models?filter=yolos
|
||||
}
|
||||
|
||||
|
||||
class YolosConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`YolosModel`]. It is used to instantiate a YOLOS
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the YOLOS
|
||||
[hustvl/yolos-base](https://huggingface.co/hustvl/yolos-base) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
image_size (`List[int]`, *optional*, defaults to `[512, 864]`):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to `16`):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
The number of input channels.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
num_detection_tokens (`int`, *optional*, defaults to `100`):
|
||||
The number of detection tokens.
|
||||
use_mid_position_embeddings (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use the mid-layer position encodings.
|
||||
auxiliary_loss (`bool`, *optional*, defaults to `False`):
|
||||
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
class_cost (`float`, *optional*, defaults to 1):
|
||||
Relative weight of the classification error in the Hungarian matching cost.
|
||||
bbox_cost (`float`, *optional*, defaults to 5):
|
||||
Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
|
||||
giou_cost (`float`, *optional*, defaults to 2):
|
||||
Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
|
||||
bbox_loss_coefficient (`float`, *optional*, defaults to 5):
|
||||
Relative weight of the L1 bounding box loss in the object detection loss.
|
||||
giou_loss_coefficient (`float`, *optional*, defaults to 2):
|
||||
Relative weight of the generalized IoU loss in the object detection loss.
|
||||
eos_coefficient (`float`, *optional*, defaults to 0.1):
|
||||
Relative classification weight of the 'no-object' class in the object detection loss.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import YolosModel, YolosConfig
|
||||
|
||||
>>> # Initializing a YOLOS hustvl/yolos-base style configuration
|
||||
>>> configuration = YolosConfig()
|
||||
|
||||
>>> # Initializing a model from the hustvl/yolos-base style configuration
|
||||
>>> model = YolosModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "yolos"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
image_size=[512, 864],
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
qkv_bias=True,
|
||||
num_detection_tokens=100,
|
||||
use_mid_position_embeddings=True,
|
||||
auxiliary_loss=False,
|
||||
class_cost=1,
|
||||
bbox_cost=5,
|
||||
giou_cost=2,
|
||||
bbox_loss_coefficient=5,
|
||||
giou_loss_coefficient=2,
|
||||
eos_coefficient=0.1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.qkv_bias = qkv_bias
|
||||
self.num_detection_tokens = num_detection_tokens
|
||||
self.use_mid_position_embeddings = use_mid_position_embeddings
|
||||
self.auxiliary_loss = auxiliary_loss
|
||||
# Hungarian matcher
|
||||
self.class_cost = class_cost
|
||||
self.bbox_cost = bbox_cost
|
||||
self.giou_cost = giou_cost
|
||||
# Loss coefficients
|
||||
self.bbox_loss_coefficient = bbox_loss_coefficient
|
||||
self.giou_loss_coefficient = giou_loss_coefficient
|
||||
self.eos_coefficient = eos_coefficient
|
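To see how these hyperparameters map onto the released checkpoints, here is a sketch of a tiny-sized configuration, mirroring the values that the conversion script below assigns for `yolos_ti`:

```python
from transformers import YolosConfig, YolosForObjectDetection

# Roughly the "tiny" variant built by get_yolos_config() for yolos_ti checkpoints
tiny_config = YolosConfig(
    hidden_size=192,
    intermediate_size=768,
    num_hidden_layers=12,
    num_attention_heads=3,
    image_size=[800, 1333],
    use_mid_position_embeddings=False,
)

model = YolosForObjectDetection(tiny_config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```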
src/transformers/models/yolos/convert_yolos_to_pytorch.py (new file, 263 lines)
@ -0,0 +1,263 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert YOLOS checkpoints from the original repository. URL: https://github.com/hustvl/YOLOS"""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import YolosConfig, YolosFeatureExtractor, YolosForObjectDetection
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_yolos_config(yolos_name):
|
||||
config = YolosConfig()
|
||||
|
||||
# size of the architecture
|
||||
if "yolos_ti" in yolos_name:
|
||||
config.hidden_size = 192
|
||||
config.intermediate_size = 768
|
||||
config.num_hidden_layers = 12
|
||||
config.num_attention_heads = 3
|
||||
config.image_size = [800, 1333]
|
||||
config.use_mid_position_embeddings = False
|
||||
elif yolos_name == "yolos_s_dWr":
|
||||
config.hidden_size = 330
|
||||
config.num_hidden_layers = 14
|
||||
config.num_attention_heads = 6
|
||||
config.intermediate_size = 1320
|
||||
elif "yolos_s" in yolos_name:
|
||||
config.hidden_size = 384
|
||||
config.intermediate_size = 1536
|
||||
config.num_hidden_layers = 12
|
||||
config.num_attention_heads = 6
|
||||
elif "yolos_b" in yolos_name:
|
||||
config.image_size = [800, 1344]
|
||||
|
||||
config.num_labels = 91
|
||||
repo_id = "datasets/huggingface/label-files"
|
||||
filename = "coco-detection-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, base_model=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
|
||||
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :]
|
||||
state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
|
||||
state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
|
||||
config.hidden_size : config.hidden_size * 2
|
||||
]
|
||||
state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :]
|
||||
state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "backbone" in name:
|
||||
name = name.replace("backbone", "vit")
|
||||
if "cls_token" in name:
|
||||
name = name.replace("cls_token", "embeddings.cls_token")
|
||||
if "det_token" in name:
|
||||
name = name.replace("det_token", "embeddings.detection_tokens")
|
||||
if "mid_pos_embed" in name:
|
||||
name = name.replace("mid_pos_embed", "encoder.mid_position_embeddings")
|
||||
if "pos_embed" in name:
|
||||
name = name.replace("pos_embed", "embeddings.position_embeddings")
|
||||
if "patch_embed.proj" in name:
|
||||
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "encoder.layer")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn" in name:
|
||||
name = name.replace("attn", "attention.self")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
if "class_embed" in name:
|
||||
name = name.replace("class_embed", "class_labels_classifier")
|
||||
if "bbox_embed" in name:
|
||||
name = name.replace("bbox_embed", "bbox_predictor")
|
||||
if "vit.norm" in name:
|
||||
name = name.replace("vit.norm", "vit.layernorm")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, model):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if "qkv" in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[2])
|
||||
dim = model.vit.encoder.layer[layer_num].attention.attention.all_head_size
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.query.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.key.weight"] = val[
|
||||
dim : dim * 2, :
|
||||
]
|
||||
orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.query.bias"] = val[:dim]
|
||||
orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"vit.encoder.layer.{layer_num}.attention.attention.value.bias"] = val[-dim:]
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_yolos_checkpoint(yolos_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our YOLOS structure.
|
||||
"""
|
||||
config = get_yolos_config(yolos_name)
|
||||
|
||||
# load original state_dict
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
||||
|
||||
# load 🤗 model
|
||||
model = YolosForObjectDetection(config)
|
||||
model.eval()
|
||||
new_state_dict = convert_state_dict(state_dict, model)
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# Check outputs on an image, prepared by YolosFeatureExtractor
|
||||
size = 800 if yolos_name != "yolos_ti" else 512
|
||||
feature_extractor = YolosFeatureExtractor(format="coco_detection", size=size)
|
||||
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
|
||||
outputs = model(**encoding)
|
||||
logits, pred_boxes = outputs.logits, outputs.pred_boxes
|
||||
|
||||
expected_slice_logits, expected_slice_boxes = None, None
|
||||
if yolos_name == "yolos_ti":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-39.5022, -11.9820, -17.6888], [-29.9574, -9.9769, -17.7691], [-42.3281, -20.7200, -30.6294]]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.4021, 0.0836, 0.7979], [0.0184, 0.2609, 0.0364], [0.1781, 0.2004, 0.2095]]
|
||||
)
|
||||
elif yolos_name == "yolos_s_200_pre":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-24.0248, -10.3024, -14.8290], [-42.0392, -16.8200, -27.4334], [-27.2743, -11.8154, -18.7148]]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.2559, 0.5455, 0.4706], [0.2989, 0.7279, 0.1875], [0.7732, 0.4017, 0.4462]]
|
||||
)
|
||||
elif yolos_name == "yolos_s_300_pre":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-36.2220, -14.4385, -23.5457], [-35.6970, -14.7583, -21.3935], [-31.5939, -13.6042, -16.8049]]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.7614, 0.2316, 0.4728], [0.7168, 0.4495, 0.3855], [0.4996, 0.1466, 0.9996]]
|
||||
)
|
||||
elif yolos_name == "yolos_s_dWr":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-42.8668, -24.1049, -41.1690], [-34.7456, -14.1274, -24.9194], [-33.7898, -12.1946, -25.6495]]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.5587, 0.2773, 0.0605], [0.5004, 0.3014, 0.9994], [0.4999, 0.1548, 0.9994]]
|
||||
)
|
||||
elif yolos_name == "yolos_base":
|
||||
expected_slice_logits = torch.tensor(
|
||||
[[-40.6064, -24.3084, -32.6447], [-55.1990, -30.7719, -35.5877], [-51.4311, -33.3507, -35.6462]]
|
||||
)
|
||||
expected_slice_boxes = torch.tensor(
|
||||
[[0.5555, 0.2794, 0.0655], [0.9049, 0.2664, 0.1894], [0.9183, 0.1984, 0.1635]]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown yolos_name: {yolos_name}")
|
||||
|
||||
assert torch.allclose(logits[0, :3, :3], expected_slice_logits, atol=1e-4)
|
||||
assert torch.allclose(pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4)
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {yolos_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model_mapping = {
|
||||
"yolos_ti": "yolos-tiny",
|
||||
"yolos_s_200_pre": "yolos-small",
|
||||
"yolos_s_300_pre": "yolos-small-300",
|
||||
"yolos_s_dWr": "yolos-small-dwr",
|
||||
"yolos_base": "yolos-base",
|
||||
}
|
||||
|
||||
print("Pushing to the hub...")
|
||||
model_name = model_mapping[yolos_name]
|
||||
feature_extractor.push_to_hub(model_name, organization="hustvl")
|
||||
model.push_to_hub(model_name, organization="hustvl")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--yolos_name",
|
||||
default="yolos_s_200_pre",
|
||||
type=str,
|
||||
help="Name of the YOLOS model you'd like to convert. Should be one of 'yolos_ti', 'yolos_s_200_pre', 'yolos_s_300_pre', 'yolos_s_dWr', 'yolos_base'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, help="Path to the original state dict (.pth file)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_yolos_checkpoint(args.yolos_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
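The script is meant to be run from the command line with the arguments above, but the conversion entry point can also be called directly; the checkpoint and output paths below are placeholders:

```python
from transformers.models.yolos.convert_yolos_to_pytorch import convert_yolos_checkpoint

# Placeholder paths: the .pth file is an original checkpoint from https://github.com/hustvl/YOLOS
convert_yolos_checkpoint(
    yolos_name="yolos_s_200_pre",
    checkpoint_path="/path/to/yolos_s_200_pre.pth",
    pytorch_dump_folder_path="/path/to/yolos-small-converted",
    push_to_hub=False,
)
```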
src/transformers/models/yolos/feature_extraction_yolos.py (new file, 916 lines)
@ -0,0 +1,916 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Feature extractor class for YOLOS."""
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
|
||||
from ...utils import TensorType, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
|
||||
def center_to_corners_format(x):
|
||||
"""
|
||||
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
|
||||
(x_0, y_0, x_1, y_1).
|
||||
"""
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
|
||||
def corners_to_center_format(x):
|
||||
"""
|
||||
Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
|
||||
y_1) to center format (center_x, center_y, width, height).
|
||||
"""
|
||||
x_transposed = x.T
|
||||
x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
|
||||
return np.stack(b, axis=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
|
||||
def masks_to_boxes(masks):
|
||||
"""
|
||||
Compute the bounding boxes around the provided panoptic segmentation masks.
|
||||
|
||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||
|
||||
Returns a [N, 4] tensor, with the boxes in corner (xyxy) format.
|
||||
"""
|
||||
if masks.size == 0:
|
||||
return np.zeros((0, 4))
|
||||
|
||||
h, w = masks.shape[-2:]
|
||||
|
||||
y = np.arange(0, h, dtype=np.float32)
|
||||
x = np.arange(0, w, dtype=np.float32)
|
||||
# see https://github.com/pytorch/pytorch/issues/50276
|
||||
y, x = np.meshgrid(y, x, indexing="ij")
|
||||
|
||||
x_mask = masks * np.expand_dims(x, axis=0)
|
||||
x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1)
|
||||
x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool)))
|
||||
x_min = x.filled(fill_value=1e8)
|
||||
x_min = x_min.reshape(x_min.shape[0], -1).min(-1)
|
||||
|
||||
y_mask = masks * np.expand_dims(y, axis=0)
|
||||
y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1)
|
||||
y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool)))
|
||||
y_min = y.filled(fill_value=1e8)
|
||||
y_min = y_min.reshape(y_min.shape[0], -1).min(-1)
|
||||
|
||||
return np.stack([x_min, y_min, x_max, y_max], 1)
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
|
||||
def rgb_to_id(color):
|
||||
if isinstance(color, np.ndarray) and len(color.shape) == 3:
|
||||
if color.dtype == np.uint8:
|
||||
color = color.astype(np.int32)
|
||||
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
|
||||
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.id_to_rgb
|
||||
def id_to_rgb(id_map):
|
||||
if isinstance(id_map, np.ndarray):
|
||||
id_map_copy = id_map.copy()
|
||||
rgb_shape = tuple(list(id_map.shape) + [3])
|
||||
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
|
||||
for i in range(3):
|
||||
rgb_map[..., i] = id_map_copy % 256
|
||||
id_map_copy //= 256
|
||||
return rgb_map
|
||||
color = []
|
||||
for _ in range(3):
|
||||
color.append(id_map % 256)
|
||||
id_map //= 256
|
||||
return color
|
||||
|
||||
|
||||
class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    r"""
    Constructs a YOLOS feature extractor.

    This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
    should refer to this superclass for more information regarding those methods.

    Args:
        format (`str`, *optional*, defaults to `"coco_detection"`):
            Data format of the annotations. One of `"coco_detection"` or `"coco_panoptic"`.
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input to a certain `size`.
        size (`int`, *optional*, defaults to 800):
            Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
            sequence like `(width, height)`, output size will be matched to this. If size is an int, the smaller edge
            of the image will be matched to this number, i.e., if `height > width`, the image will be rescaled to
            `(size * height / width, size)`.
        max_size (`int`, *optional*, defaults to `1333`):
            The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
            set to `True`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input with mean and standard deviation.
        image_mean (`List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
        image_std (`List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
            ImageNet std.
    """

    model_input_names = ["pixel_values"]

    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.__init__
    def __init__(
        self,
        format="coco_detection",
        do_resize=True,
        size=800,
        max_size=1333,
        do_normalize=True,
        image_mean=None,
        image_std=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.format = self._is_valid_format(format)
        self.do_resize = do_resize
        self.size = size
        self.max_size = max_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]  # ImageNet mean
        self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]  # ImageNet std

    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._is_valid_format
    def _is_valid_format(self, format):
        if format not in ["coco_detection", "coco_panoptic"]:
            raise ValueError(f"Format {format} not supported")
        return format

    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.prepare
    def prepare(self, image, target, return_segmentation_masks=False, masks_path=None):
        if self.format == "coco_detection":
            image, target = self.prepare_coco_detection(image, target, return_segmentation_masks)
            return image, target
        elif self.format == "coco_panoptic":
            image, target = self.prepare_coco_panoptic(image, target, masks_path)
            return image, target
        else:
            raise ValueError(f"Format {self.format} not supported")

    # Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.convert_coco_poly_to_mask
    def convert_coco_poly_to_mask(self, segmentations, height, width):

        try:
            from pycocotools import mask as coco_mask
        except ImportError:
            raise ImportError("Pycocotools is not installed in your environment.")

        masks = []
        for polygons in segmentations:
            rles = coco_mask.frPyObjects(polygons, height, width)
            mask = coco_mask.decode(rles)
            if len(mask.shape) < 3:
                mask = mask[..., None]
            mask = np.asarray(mask, dtype=np.uint8)
            mask = np.any(mask, axis=2)
            masks.append(mask)
        if masks:
            masks = np.stack(masks, axis=0)
        else:
            masks = np.zeros((0, height, width), dtype=np.uint8)

        return masks

# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.prepare_coco_detection
|
||||
def prepare_coco_detection(self, image, target, return_segmentation_masks=False):
|
||||
"""
|
||||
Convert the target in COCO format into the format expected by DETR.
|
||||
"""
|
||||
w, h = image.size
|
||||
|
||||
image_id = target["image_id"]
|
||||
image_id = np.asarray([image_id], dtype=np.int64)
|
||||
|
||||
# get all COCO annotations for the given image
|
||||
anno = target["annotations"]
|
||||
|
||||
anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0]
|
||||
|
||||
boxes = [obj["bbox"] for obj in anno]
|
||||
# guard against no boxes via resizing
|
||||
boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
|
||||
boxes[:, 2:] += boxes[:, :2]
|
||||
boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=w)
|
||||
boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=h)
|
||||
|
||||
classes = [obj["category_id"] for obj in anno]
|
||||
classes = np.asarray(classes, dtype=np.int64)
|
||||
|
||||
if return_segmentation_masks:
|
||||
segmentations = [obj["segmentation"] for obj in anno]
|
||||
masks = self.convert_coco_poly_to_mask(segmentations, h, w)
|
||||
|
||||
keypoints = None
|
||||
if anno and "keypoints" in anno[0]:
|
||||
keypoints = [obj["keypoints"] for obj in anno]
|
||||
keypoints = np.asarray(keypoints, dtype=np.float32)
|
||||
num_keypoints = keypoints.shape[0]
|
||||
if num_keypoints:
|
||||
keypoints = keypoints.reshape((-1, 3))
|
||||
|
||||
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
||||
boxes = boxes[keep]
|
||||
classes = classes[keep]
|
||||
if return_segmentation_masks:
|
||||
masks = masks[keep]
|
||||
if keypoints is not None:
|
||||
keypoints = keypoints[keep]
|
||||
|
||||
target = {}
|
||||
target["boxes"] = boxes
|
||||
target["class_labels"] = classes
|
||||
if return_segmentation_masks:
|
||||
target["masks"] = masks
|
||||
target["image_id"] = image_id
|
||||
if keypoints is not None:
|
||||
target["keypoints"] = keypoints
|
||||
|
||||
# for conversion to coco api
|
||||
area = np.asarray([obj["area"] for obj in anno], dtype=np.float32)
|
||||
iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno], dtype=np.int64)
|
||||
target["area"] = area[keep]
|
||||
target["iscrowd"] = iscrowd[keep]
|
||||
|
||||
target["orig_size"] = np.asarray([int(h), int(w)], dtype=np.int64)
|
||||
target["size"] = np.asarray([int(h), int(w)], dtype=np.int64)
|
||||
|
||||
return image, target
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.prepare_coco_panoptic
|
||||
def prepare_coco_panoptic(self, image, target, masks_path, return_masks=True):
|
||||
w, h = image.size
|
||||
ann_info = target.copy()
|
||||
ann_path = pathlib.Path(masks_path) / ann_info["file_name"]
|
||||
|
||||
if "segments_info" in ann_info:
|
||||
masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
|
||||
masks = rgb_to_id(masks)
|
||||
|
||||
ids = np.array([ann["id"] for ann in ann_info["segments_info"]])
|
||||
masks = masks == ids[:, None, None]
|
||||
masks = np.asarray(masks, dtype=np.uint8)
|
||||
|
||||
labels = np.asarray([ann["category_id"] for ann in ann_info["segments_info"]], dtype=np.int64)
|
||||
|
||||
target = {}
|
||||
target["image_id"] = np.asarray(
|
||||
[ann_info["image_id"] if "image_id" in ann_info else ann_info["id"]], dtype=np.int64
|
||||
)
|
||||
if return_masks:
|
||||
target["masks"] = masks
|
||||
target["class_labels"] = labels
|
||||
|
||||
target["boxes"] = masks_to_boxes(masks)
|
||||
|
||||
target["size"] = np.asarray([int(h), int(w)], dtype=np.int64)
|
||||
target["orig_size"] = np.asarray([int(h), int(w)], dtype=np.int64)
|
||||
if "segments_info" in ann_info:
|
||||
target["iscrowd"] = np.asarray([ann["iscrowd"] for ann in ann_info["segments_info"]], dtype=np.int64)
|
||||
target["area"] = np.asarray([ann["area"] for ann in ann_info["segments_info"]], dtype=np.float32)
|
||||
|
||||
return image, target
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._resize
|
||||
def _resize(self, image, size, target=None, max_size=None):
|
||||
"""
|
||||
Resize the image to the given size. Size can be min_size (scalar) or (w, h) tuple. If size is an int, smaller
|
||||
edge of the image will be matched to this number.
|
||||
|
||||
If given, also resize the target accordingly.
|
||||
"""
|
||||
if not isinstance(image, Image.Image):
|
||||
image = self.to_pil_image(image)
|
||||
|
||||
def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
||||
w, h = image_size
|
||||
if max_size is not None:
|
||||
min_original_size = float(min((w, h)))
|
||||
max_original_size = float(max((w, h)))
|
||||
if max_original_size / min_original_size * size > max_size:
|
||||
size = int(round(max_size * min_original_size / max_original_size))
|
||||
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return (h, w)
|
||||
|
||||
if w < h:
|
||||
ow = size
|
||||
oh = int(size * h / w)
|
||||
else:
|
||||
oh = size
|
||||
ow = int(size * w / h)
|
||||
|
||||
return (oh, ow)
|
||||
|
||||
def get_size(image_size, size, max_size=None):
|
||||
if isinstance(size, (list, tuple)):
|
||||
return size
|
||||
else:
|
||||
# size returned must be (w, h) since we use PIL to resize images
|
||||
# so we revert the tuple
|
||||
return get_size_with_aspect_ratio(image_size, size, max_size)[::-1]
|
||||
|
||||
size = get_size(image.size, size, max_size)
|
||||
rescaled_image = self.resize(image, size=size)
|
||||
|
||||
if target is None:
|
||||
return rescaled_image, None
|
||||
|
||||
ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
|
||||
ratio_width, ratio_height = ratios
|
||||
|
||||
target = target.copy()
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
|
||||
target["boxes"] = scaled_boxes
|
||||
|
||||
if "area" in target:
|
||||
area = target["area"]
|
||||
scaled_area = area * (ratio_width * ratio_height)
|
||||
target["area"] = scaled_area
|
||||
|
||||
w, h = size
|
||||
target["size"] = np.asarray([h, w], dtype=np.int64)
|
||||
|
||||
if "masks" in target:
|
||||
# use PyTorch as current workaround
|
||||
# TODO replace by self.resize
|
||||
masks = torch.from_numpy(target["masks"][:, None]).float()
|
||||
interpolated_masks = nn.functional.interpolate(masks, size=(h, w), mode="nearest")[:, 0] > 0.5
|
||||
target["masks"] = interpolated_masks.numpy()
|
||||
|
||||
return rescaled_image, target
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._normalize
|
||||
def _normalize(self, image, mean, std, target=None):
|
||||
"""
|
||||
Normalize the image with a certain mean and std.
|
||||
|
||||
If given, also normalize the target bounding boxes based on the size of the image.
|
||||
"""
|
||||
|
||||
image = self.normalize(image, mean=mean, std=std)
|
||||
if target is None:
|
||||
return image, None
|
||||
|
||||
target = target.copy()
|
||||
h, w = image.shape[-2:]
|
||||
|
||||
if "boxes" in target:
|
||||
boxes = target["boxes"]
|
||||
boxes = corners_to_center_format(boxes)
|
||||
boxes = boxes / np.asarray([w, h, w, h], dtype=np.float32)
|
||||
target["boxes"] = boxes
|
||||
|
||||
return image, target
|
||||
|
||||
    def __call__(
        self,
        images: ImageInput,
        annotations: Union[List[Dict], List[List[Dict]]] = None,
        return_segmentation_masks: Optional[bool] = False,
        masks_path: Optional[pathlib.Path] = None,
        padding: Optional[bool] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare one or several image(s) and optional annotations for the model. Images are by default
        padded up to the largest image in a batch.

        <Tip warning={true}>

        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
        PIL images.

        </Tip>

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
                the number of channels, and H and W are the image height and width.

            annotations (`Dict`, `List[Dict]`, *optional*):
                The corresponding annotations in COCO format.

                In case [`YolosFeatureExtractor`] was initialized with `format = "coco_detection"`, the annotations
                for each image should have the following format: {'image_id': int, 'annotations': [annotation]}, with
                the annotations being a list of COCO object annotations.

                In case [`YolosFeatureExtractor`] was initialized with `format = "coco_panoptic"`, the annotations for
                each image should have the following format: {'image_id': int, 'file_name': str, 'segments_info':
                [segment_info]} with segments_info being a list of COCO panoptic annotations.

            return_segmentation_masks (`bool`, *optional*, defaults to `False`):
                Whether to also include instance segmentation masks as part of the labels in case `format =
                "coco_detection"`.

            masks_path (`pathlib.Path`, *optional*):
                Path to the directory containing the PNG files that store the class-agnostic image segmentations. Only
                relevant in case [`YolosFeatureExtractor`] was initialized with `format = "coco_panoptic"`.

            padding (`bool`, *optional*, defaults to `True`):
                Whether or not to pad images up to the largest image in a batch.

            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
                objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **pixel_values** -- Pixel values to be fed to a model.
            - **labels** -- Optional labels to be fed to a model (when `annotations` are provided).
        """
# Input type checking for clearer error
|
||||
|
||||
valid_images = False
|
||||
valid_annotations = False
|
||||
valid_masks_path = False
|
||||
|
||||
# Check that images has a valid type
|
||||
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
|
||||
valid_images = True
|
||||
elif isinstance(images, (list, tuple)):
|
||||
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
|
||||
valid_images = True
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError(
|
||||
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
|
||||
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(images, (list, tuple))
|
||||
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
|
||||
)
|
||||
|
||||
# Check that annotations has a valid type
|
||||
if annotations is not None:
|
||||
if not is_batched:
|
||||
if self.format == "coco_detection":
|
||||
if isinstance(annotations, dict) and "image_id" in annotations and "annotations" in annotations:
|
||||
if isinstance(annotations["annotations"], (list, tuple)):
|
||||
# an image can have no annotations
|
||||
if len(annotations["annotations"]) == 0 or isinstance(annotations["annotations"][0], dict):
|
||||
valid_annotations = True
|
||||
elif self.format == "coco_panoptic":
|
||||
if isinstance(annotations, dict) and "image_id" in annotations and "segments_info" in annotations:
|
||||
if isinstance(annotations["segments_info"], (list, tuple)):
|
||||
# an image can have no segments (?)
|
||||
if len(annotations["segments_info"]) == 0 or isinstance(
|
||||
annotations["segments_info"][0], dict
|
||||
):
|
||||
valid_annotations = True
|
||||
else:
|
||||
if isinstance(annotations, (list, tuple)):
|
||||
if len(images) != len(annotations):
|
||||
raise ValueError("There must be as many annotations as there are images")
|
||||
if isinstance(annotations[0], Dict):
|
||||
if self.format == "coco_detection":
|
||||
if isinstance(annotations[0]["annotations"], (list, tuple)):
|
||||
valid_annotations = True
|
||||
elif self.format == "coco_panoptic":
|
||||
if isinstance(annotations[0]["segments_info"], (list, tuple)):
|
||||
valid_annotations = True
|
||||
|
||||
if not valid_annotations:
|
||||
raise ValueError(
|
||||
"""
|
||||
Annotations must be of type `Dict` (single image) or `List[Dict]` (batch of images). In case of object
|
||||
detection, each dictionary should contain the keys 'image_id' and 'annotations', with the latter
|
||||
being a list of annotations in COCO format. In case of panoptic segmentation, each dictionary
|
||||
should contain the keys 'file_name', 'image_id' and 'segments_info', with the latter being a list
|
||||
of annotations in COCO format.
|
||||
"""
|
||||
)
|
||||
|
||||
# Check that masks_path has a valid type
|
||||
if masks_path is not None:
|
||||
if self.format == "coco_panoptic":
|
||||
if isinstance(masks_path, pathlib.Path):
|
||||
valid_masks_path = True
|
||||
if not valid_masks_path:
|
||||
raise ValueError(
|
||||
"The path to the directory containing the mask PNG files should be provided as a `pathlib.Path` object."
|
||||
)
|
||||
|
||||
if not is_batched:
|
||||
images = [images]
|
||||
if annotations is not None:
|
||||
annotations = [annotations]
|
||||
|
||||
# prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
|
||||
if annotations is not None:
|
||||
for idx, (image, target) in enumerate(zip(images, annotations)):
|
||||
if not isinstance(image, Image.Image):
|
||||
image = self.to_pil_image(image)
|
||||
image, target = self.prepare(image, target, return_segmentation_masks, masks_path)
|
||||
images[idx] = image
|
||||
annotations[idx] = target
|
||||
|
||||
# transformations (resizing + normalization)
|
||||
if self.do_resize and self.size is not None:
|
||||
if annotations is not None:
|
||||
for idx, (image, target) in enumerate(zip(images, annotations)):
|
||||
image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size)
|
||||
images[idx] = image
|
||||
annotations[idx] = target
|
||||
else:
|
||||
for idx, image in enumerate(images):
|
||||
images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]
|
||||
|
||||
if self.do_normalize:
|
||||
if annotations is not None:
|
||||
for idx, (image, target) in enumerate(zip(images, annotations)):
|
||||
image, target = self._normalize(
|
||||
image=image, mean=self.image_mean, std=self.image_std, target=target
|
||||
)
|
||||
images[idx] = image
|
||||
annotations[idx] = target
|
||||
else:
|
||||
images = [
|
||||
self._normalize(image=image, mean=self.image_mean, std=self.image_std)[0] for image in images
|
||||
]
|
||||
|
||||
if padding:
|
||||
# pad images up to largest image in batch
|
||||
max_size = self._max_by_axis([list(image.shape) for image in images])
|
||||
c, h, w = max_size
|
||||
padded_images = []
|
||||
for image in images:
|
||||
# create padded image
|
||||
padded_image = np.zeros((c, h, w), dtype=np.float32)
|
||||
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
|
||||
padded_images.append(padded_image)
|
||||
images = padded_images
|
||||
|
||||
# return as BatchFeature
|
||||
data = {}
|
||||
data["pixel_values"] = images
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
if annotations is not None:
|
||||
# Convert to TensorType
|
||||
tensor_type = return_tensors
|
||||
if not isinstance(tensor_type, TensorType):
|
||||
tensor_type = TensorType(tensor_type)
|
||||
|
||||
if not tensor_type == TensorType.PYTORCH:
|
||||
raise ValueError("Only PyTorch is supported for the moment.")
|
||||
else:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
|
||||
encoded_inputs["labels"] = [
|
||||
{k: torch.from_numpy(v) for k, v in target.items()} for target in annotations
|
||||
]
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor._max_by_axis
|
||||
def _max_by_axis(self, the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
def pad(self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None):
|
||||
"""
|
||||
Pad images up to the largest image in a batch.
|
||||
|
||||
Args:
|
||||
pixel_values_list (`List[torch.Tensor]`):
|
||||
List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
|
||||
objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following field:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model.
|
||||
|
||||
"""
|
||||
|
||||
max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
|
||||
c, h, w = max_size
|
||||
padded_images = []
|
||||
for image in pixel_values_list:
|
||||
# create padded image
|
||||
padded_image = np.zeros((c, h, w), dtype=np.float32)
|
||||
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
|
||||
padded_images.append(padded_image)
|
||||
|
||||
# return as BatchFeature
|
||||
data = {"pixel_values": padded_images}
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process
|
||||
def post_process(self, outputs, target_sizes):
|
||||
"""
|
||||
Converts the output of [`DetrForObjectDetection`] into the format expected by the COCO api. Only supports
|
||||
PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`DetrObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
||||
image size (before any data augmentation). For visualization, this should be the image size after data
|
||||
augmentation, but before padding.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
||||
|
||||
if len(out_logits) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
||||
if target_sizes.shape[1] != 2:
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
prob = nn.functional.softmax(out_logits, -1)
|
||||
scores, labels = prob[..., :-1].max(-1)
|
||||
|
||||
# convert to [x0, y0, x1, y1] format
|
||||
boxes = center_to_corners_format(out_bbox)
|
||||
# and from relative [0, 1] to absolute [0, height] coordinates
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
|
||||
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_segmentation
|
||||
def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
|
||||
"""
|
||||
Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
|
||||
|
||||
Parameters:
|
||||
outputs ([`DetrSegmentationOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
|
||||
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
|
||||
threshold (`float`, *optional*, defaults to 0.9):
|
||||
Threshold to use to filter out queries.
|
||||
mask_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
out_logits, raw_masks = outputs.logits, outputs.pred_masks
|
||||
preds = []
|
||||
|
||||
def to_tuple(tup):
|
||||
if isinstance(tup, tuple):
|
||||
return tup
|
||||
return tuple(tup.cpu().tolist())
|
||||
|
||||
for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
|
||||
# we filter empty queries and detection below threshold
|
||||
scores, labels = cur_logits.softmax(-1).max(-1)
|
||||
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
|
||||
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
||||
cur_scores = cur_scores[keep]
|
||||
cur_classes = cur_classes[keep]
|
||||
cur_masks = cur_masks[keep]
|
||||
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
||||
cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
|
||||
|
||||
predictions = {"scores": cur_scores, "labels": cur_classes, "masks": cur_masks}
|
||||
preds.append(predictions)
|
||||
return preds
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_instance
|
||||
def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
|
||||
"""
|
||||
Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
|
||||
PyTorch.
|
||||
|
||||
Args:
|
||||
results (`List[Dict]`):
|
||||
Results list obtained by [`~DetrFeatureExtractor.post_process`], to which "masks" results will be
|
||||
added.
|
||||
outputs ([`DetrSegmentationOutput`]):
|
||||
Raw outputs of the model.
|
||||
orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
||||
image size (before any data augmentation).
|
||||
max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
|
||||
original image size (before any data augmentation).
|
||||
threshold (`float`, *optional*, defaults to 0.5):
|
||||
Threshold to use when turning the predicted masks into binary values.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
|
||||
image in the batch as predicted by the model.
|
||||
"""
|
||||
|
||||
if len(orig_target_sizes) != len(max_target_sizes):
|
||||
raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
|
||||
max_h, max_w = max_target_sizes.max(0)[0].tolist()
|
||||
outputs_masks = outputs.pred_masks.squeeze(2)
|
||||
outputs_masks = nn.functional.interpolate(
|
||||
outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
|
||||
)
|
||||
outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
|
||||
|
||||
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
|
||||
img_h, img_w = t[0], t[1]
|
||||
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
|
||||
results[i]["masks"] = nn.functional.interpolate(
|
||||
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
|
||||
).byte()
|
||||
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.detr.feature_extraction_detr.DetrFeatureExtractor.post_process_panoptic
|
||||
def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
|
||||
"""
|
||||
Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
|
||||
|
||||
Parameters:
|
||||
outputs ([`DetrSegmentationOutput`]):
|
||||
Raw outputs of the model.
|
||||
processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
|
||||
Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
|
||||
augmentation but before batching.
|
||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
|
||||
Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
|
||||
None, it will default to the `processed_sizes`.
|
||||
is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
|
||||
Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
|
||||
If not set, defaults to the `is_thing_map` of COCO panoptic.
|
||||
threshold (`float`, *optional*, defaults to 0.85):
|
||||
Threshold to use to filter out queries.
|
||||
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
|
||||
an image in the batch as predicted by the model.
|
||||
"""
|
||||
if target_sizes is None:
|
||||
target_sizes = processed_sizes
|
||||
if len(processed_sizes) != len(target_sizes):
|
||||
raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
|
||||
|
||||
if is_thing_map is None:
|
||||
# default to is_thing_map of COCO panoptic
|
||||
is_thing_map = {i: i <= 90 for i in range(201)}
|
||||
|
||||
out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
|
||||
if not len(out_logits) == len(raw_masks) == len(target_sizes):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
|
||||
)
|
||||
preds = []
|
||||
|
||||
def to_tuple(tup):
|
||||
if isinstance(tup, tuple):
|
||||
return tup
|
||||
return tuple(tup.cpu().tolist())
|
||||
|
||||
for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
|
||||
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
|
||||
):
|
||||
# we filter empty queries and detection below threshold
|
||||
scores, labels = cur_logits.softmax(-1).max(-1)
|
||||
keep = labels.ne(outputs.logits.shape[-1] - 1) & (scores > threshold)
|
||||
cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
|
||||
cur_scores = cur_scores[keep]
|
||||
cur_classes = cur_classes[keep]
|
||||
cur_masks = cur_masks[keep]
|
||||
cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
|
||||
cur_boxes = center_to_corners_format(cur_boxes[keep])
|
||||
|
||||
h, w = cur_masks.shape[-2:]
|
||||
if len(cur_boxes) != len(cur_classes):
|
||||
raise ValueError("Not as many boxes as there are classes")
|
||||
|
||||
# It may be that we have several predicted masks for the same stuff class.
|
||||
# In the following, we track the list of mask ids for each stuff class (they are merged later on)
|
||||
cur_masks = cur_masks.flatten(1)
|
||||
stuff_equiv_classes = defaultdict(lambda: [])
|
||||
for k, label in enumerate(cur_classes):
|
||||
if not is_thing_map[label.item()]:
|
||||
stuff_equiv_classes[label.item()].append(k)
|
||||
|
||||
def get_ids_area(masks, scores, dedup=False):
|
||||
# This helper function creates the final panoptic segmentation image
|
||||
# It also returns the area of the masks that appear on the image
|
||||
|
||||
m_id = masks.transpose(0, 1).softmax(-1)
|
||||
|
||||
if m_id.shape[-1] == 0:
|
||||
# We didn't detect any mask :(
|
||||
m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
|
||||
else:
|
||||
m_id = m_id.argmax(-1).view(h, w)
|
||||
|
||||
if dedup:
|
||||
# Merge the masks corresponding to the same stuff class
|
||||
for equiv in stuff_equiv_classes.values():
|
||||
if len(equiv) > 1:
|
||||
for eq_id in equiv:
|
||||
m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
|
||||
|
||||
final_h, final_w = to_tuple(target_size)
|
||||
|
||||
seg_img = Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
|
||||
seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
|
||||
|
||||
np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
|
||||
np_seg_img = np_seg_img.view(final_h, final_w, 3)
|
||||
np_seg_img = np_seg_img.numpy()
|
||||
|
||||
m_id = torch.from_numpy(rgb_to_id(np_seg_img))
|
||||
|
||||
area = []
|
||||
for i in range(len(scores)):
|
||||
area.append(m_id.eq(i).sum().item())
|
||||
return area, seg_img
|
||||
|
||||
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
|
||||
if cur_classes.numel() > 0:
|
||||
# We now filter empty masks as long as we find some
|
||||
while True:
|
||||
filtered_small = torch.as_tensor(
|
||||
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
|
||||
)
|
||||
if filtered_small.any().item():
|
||||
cur_scores = cur_scores[~filtered_small]
|
||||
cur_classes = cur_classes[~filtered_small]
|
||||
cur_masks = cur_masks[~filtered_small]
|
||||
area, seg_img = get_ids_area(cur_masks, cur_scores)
|
||||
else:
|
||||
break
|
||||
|
||||
else:
|
||||
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
|
||||
|
||||
segments_info = []
|
||||
for i, a in enumerate(area):
|
||||
cat = cur_classes[i].item()
|
||||
segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
|
||||
del cur_classes
|
||||
|
||||
with io.BytesIO() as out:
|
||||
seg_img.save(out, format="PNG")
|
||||
predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
|
||||
preds.append(predictions)
|
||||
return preds
|
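# --- Editor's usage sketch (illustrative, not part of the original file) --------------------------
# A minimal end-to-end flow for the feature extractor and post-processing above. Assumptions: the
# "hustvl/yolos-small" checkpoint (also used in the tests below) is available, "cats.png" is a
# hypothetical local image path, and the single hand-written COCO annotation is purely illustrative.
def _yolos_feature_extractor_usage_sketch():
    import torch
    from PIL import Image

    from transformers import YolosFeatureExtractor, YolosForObjectDetection

    feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
    model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")

    image = Image.open("cats.png")  # hypothetical image path
    # COCO detection format: {'image_id': int, 'annotations': [annotation]}
    annotations = {
        "image_id": 1,
        "annotations": [
            {"bbox": [10.0, 20.0, 100.0, 50.0], "category_id": 17, "area": 5000.0, "iscrowd": 0}
        ],
    }

    # resizes, normalizes and pads the image; converts annotations to normalized (cx, cy, w, h) boxes
    encoding = feature_extractor(images=image, annotations=annotations, return_tensors="pt")
    pixel_values = encoding["pixel_values"]  # shape (1, 3, H, W)
    labels = encoding["labels"]  # list with one dict of tensors (boxes, class_labels, area, ...)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    # rescale the predicted boxes back to the original image size (h, w)
    target_sizes = torch.tensor([image.size[::-1]])
    results = feature_extractor.post_process(outputs, target_sizes)
    print(labels[0]["boxes"].shape, results[0]["boxes"].shape)
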
src/transformers/models/yolos/modeling_yolos.py (new executable file, 1324 lines; diff suppressed because it is too large)
@@ -4697,6 +4697,30 @@ def load_tf_weights_in_xlnet(*args, **kwargs):
|
||||
requires_backends(load_tf_weights_in_xlnet, ["torch"])
|
||||
|
||||
|
||||
YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class YolosForObjectDetection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class YolosModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class YolosPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
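# --- Editor's sketch of the dummy-object pattern (simplified, not the actual transformers code) ---
# The hunk above exposes placeholder classes when PyTorch is not installed. A minimal stand-in for
# the idea, under the assumption that the real `requires_backends` raises an ImportError when a
# required backend is missing:
def _sketch_requires_backends(obj, backends):
    import importlib.util

    missing = [name for name in backends if importlib.util.find_spec(name) is None]
    if missing:
        raise ImportError(f"{obj.__class__.__name__} requires the following backends: {missing}")


class _YolosModelDummySketch:
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # fails with a helpful error instead of an obscure import failure
        _sketch_requires_backends(self, self._backends)
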
@@ -141,3 +141,10 @@ class ViTFeatureExtractor(metaclass=DummyObject):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class YolosFeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
tests/yolos/__init__.py (new file, empty)
tests/yolos/test_feature_extraction_yolos.py (new file, 336 lines)
@@ -0,0 +1,336 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import YolosFeatureExtractor
|
||||
|
||||
|
||||
class YolosFeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=18,
|
||||
max_size=1333, # by setting max_size > max_resolution we're effectively not testing this :p
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.max_size = max_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"max_size": self.max_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
}
|
||||
|
||||
def get_expected_values(self, image_inputs, batched=False):
|
||||
"""
|
||||
This function computes the expected height and width when providing images to YolosFeatureExtractor,
|
||||
assuming do_resize is set to True with a scalar size.
|
||||
"""
|
||||
if not batched:
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
if w < h:
|
||||
expected_height = int(self.size * h / w)
|
||||
expected_width = self.size
|
||||
elif w > h:
|
||||
expected_height = self.size
|
||||
expected_width = int(self.size * w / h)
|
||||
else:
|
||||
expected_height = self.size
|
||||
expected_width = self.size
|
||||
|
||||
else:
|
||||
expected_values = []
|
||||
for image in image_inputs:
|
||||
expected_height, expected_width = self.get_expected_values([image])
|
||||
expected_values.append((expected_height, expected_width))
|
||||
expected_height = max(expected_values, key=lambda item: item[0])[0]
|
||||
expected_width = max(expected_values, key=lambda item: item[1])[1]
|
||||
|
||||
return expected_height, expected_width
|
||||
|
||||
|
||||
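# --- Editor's worked example (illustrative, not part of the original tests) -----------------------
# The tester above expects the smaller image edge to be matched to `size` while keeping the aspect
# ratio (max_size is set high enough that it never kicks in). With size=18 and a hypothetical
# 300x400 (width x height) input:
def _expected_resize_example():
    size = 18
    w, h = 300, 400
    if w < h:
        expected_height, expected_width = int(size * h / w), size
    elif w > h:
        expected_height, expected_width = size, int(size * w / h)
    else:
        expected_height = expected_width = size
    assert (expected_height, expected_width) == (24, 18)
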
@require_torch
|
||||
@require_vision
|
||||
class YolosFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = YolosFeatureExtractor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = YolosFeatureExtractionTester(self)
|
||||
|
||||
@property
|
||||
def feat_extract_dict(self):
|
||||
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
def test_feat_extract_properties(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "max_size"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PIL images
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
|
||||
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, self.feature_extract_tester.num_channels, expected_height, expected_width),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
|
||||
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
expected_height,
|
||||
expected_width,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
|
||||
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, self.feature_extract_tester.num_channels, expected_height, expected_width),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
|
||||
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
expected_height,
|
||||
expected_width,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
|
||||
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, self.feature_extract_tester.num_channels, expected_height, expected_width),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
|
||||
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
expected_height,
|
||||
expected_width,
|
||||
),
|
||||
)
|
||||
|
||||
def test_equivalence_padding(self):
|
||||
# Initialize feature_extractors
|
||||
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test whether the method "pad" and calling the feature extractor return the same tensors
|
||||
encoded_images_with_method = feature_extractor_1.pad(image_inputs, return_tensors="pt")
|
||||
encoded_images = feature_extractor_2(image_inputs, return_tensors="pt")
|
||||
|
||||
assert torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_call_pytorch_with_coco_detection_annotations(self):
|
||||
# prepare image and target
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
|
||||
target = json.loads(f.read())
|
||||
|
||||
target = {"image_id": 39769, "annotations": target}
|
||||
|
||||
# encode them
|
||||
feature_extractor = YolosFeatureExtractor.from_pretrained("hustvl/yolos-small")
|
||||
encoding = feature_extractor(images=image, annotations=target, return_tensors="pt")
|
||||
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 800, 1066])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
|
||||
assert torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)
|
||||
|
||||
# verify area
|
||||
expected_area = torch.tensor([5887.9600, 11250.2061, 489353.8438, 837122.7500, 147967.5156, 165732.3438])
|
||||
assert torch.allclose(encoding["labels"][0]["area"], expected_area)
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
|
||||
assert torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
assert torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
assert torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
|
||||
assert torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
assert torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)
|
||||
# verify size
|
||||
expected_size = torch.tensor([800, 1066])
|
||||
assert torch.allclose(encoding["labels"][0]["size"], expected_size)
|
||||
|
||||
@slow
|
||||
def test_call_pytorch_with_coco_panoptic_annotations(self):
|
||||
# prepare image, target and masks_path
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
with open("./tests/fixtures/tests_samples/COCO/coco_panoptic_annotations.txt", "r") as f:
|
||||
target = json.loads(f.read())
|
||||
|
||||
target = {"file_name": "000000039769.png", "image_id": 39769, "segments_info": target}
|
||||
|
||||
masks_path = pathlib.Path("./tests/fixtures/tests_samples/COCO/coco_panoptic")
|
||||
|
||||
# encode them
|
||||
feature_extractor = YolosFeatureExtractor(format="coco_panoptic")
|
||||
encoding = feature_extractor(images=image, annotations=target, masks_path=masks_path, return_tensors="pt")
|
||||
|
||||
# verify pixel values
|
||||
expected_shape = torch.Size([1, 3, 800, 1066])
|
||||
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([0.2796, 0.3138, 0.3481])
|
||||
assert torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4)
|
||||
|
||||
# verify area
|
||||
expected_area = torch.tensor([147979.6875, 165527.0469, 484638.5938, 11292.9375, 5879.6562, 7634.1147])
|
||||
assert torch.allclose(encoding["labels"][0]["area"], expected_area)
|
||||
# verify boxes
|
||||
expected_boxes_shape = torch.Size([6, 4])
|
||||
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
|
||||
expected_boxes_slice = torch.tensor([0.2625, 0.5437, 0.4688, 0.8625])
|
||||
assert torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3)
|
||||
# verify image_id
|
||||
expected_image_id = torch.tensor([39769])
|
||||
assert torch.allclose(encoding["labels"][0]["image_id"], expected_image_id)
|
||||
# verify is_crowd
|
||||
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
|
||||
assert torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd)
|
||||
# verify class_labels
|
||||
expected_class_labels = torch.tensor([17, 17, 63, 75, 75, 93])
|
||||
assert torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels)
|
||||
# verify masks
|
||||
expected_masks_sum = 822338
|
||||
self.assertEqual(encoding["labels"][0]["masks"].sum().item(), expected_masks_sum)
|
||||
# verify orig_size
|
||||
expected_orig_size = torch.tensor([480, 640])
|
||||
assert torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size)
|
||||
# verify size
|
||||
expected_size = torch.tensor([800, 1066])
|
||||
assert torch.allclose(encoding["labels"][0]["size"], expected_size)
|
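# --- Editor's note (illustrative sketch, not part of the original tests) --------------------------
# In the COCO panoptic format exercised above, each segment id is encoded in the PNG colors as
# id = R + 256 * G + 256**2 * B (see `rgb_to_id` in feature_extraction_yolos.py). A tiny example
# with a hypothetical 1x1 "mask" pixel:
def _panoptic_id_encoding_example():
    import numpy as np

    color = np.array([[[21, 3, 0]]], dtype=np.uint8).astype(np.int32)
    segment_id = color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    assert int(segment_id[0, 0]) == 21 + 256 * 3  # 789
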
tests/yolos/test_modeling_yolos.py (new file, 373 lines)
@@ -0,0 +1,373 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch YOLOS model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import YolosConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ..test_configuration_common import ConfigTester
|
||||
from ..test_modeling_common import ModelTesterMixin, floats_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import YolosForObjectDetection, YolosModel
|
||||
from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
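# --- Editor's worked example (illustrative, not part of the original tests) -----------------------
# YolosModelTester below derives the expected sequence length as num_patches + 1 (for the [CLS]
# token) + num_detection_tokens. With its defaults (image_size=[30, 30], patch_size=2,
# num_detection_tokens=10) this gives 15 * 15 + 1 + 10 = 236:
def _expected_seq_len_example():
    image_size, patch_size, num_detection_tokens = (30, 30), 2, 10
    num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)  # 225
    expected_seq_len = num_patches + 1 + num_detection_tokens
    assert expected_seq_len == 236
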
class YolosModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=[30, 30],
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
n_targets=8,
|
||||
num_detection_tokens=10,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.scope = scope
|
||||
self.n_targets = n_targets
|
||||
self.num_detection_tokens = num_detection_tokens
|
||||
# we set the expected sequence length (which is used in several tests)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
|
||||
image_size = to_2tuple(self.image_size)
|
||||
patch_size = to_2tuple(self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.expected_seq_len = num_patches + 1 + self.num_detection_tokens
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
# labels is a list of Dict (each Dict being the labels for a given example in the batch)
|
||||
labels = []
|
||||
for i in range(self.batch_size):
|
||||
target = {}
|
||||
target["class_labels"] = torch.randint(
|
||||
high=self.num_labels, size=(self.n_targets,), device=torch_device
|
||||
)
|
||||
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
|
||||
labels.append(target)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return YolosConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
num_detection_tokens=self.num_detection_tokens,
|
||||
num_labels=self.num_labels,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = YolosModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.expected_seq_len, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_for_object_detection(self, config, pixel_values, labels):
|
||||
model = YolosForObjectDetection(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(pixel_values=pixel_values)
|
||||
result = model(pixel_values)
|
||||
|
||||
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_detection_tokens, self.num_labels + 1)
        )
        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_detection_tokens, 4))

        result = model(pixel_values=pixel_values, labels=labels)

        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_detection_tokens, self.num_labels + 1)
        )
        self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_detection_tokens, 4))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class YolosModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as YOLOS does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (YolosModel, YolosForObjectDetection) if is_torch_available() else ()

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False

    # special case for head model
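    # YolosForObjectDetection expects `labels` as a list of dicts (one per image), so the common helper is
    # overridden below to build dummy targets (class index 1, unit boxes) that simply exercise the loss computation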
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class.__name__ == "YolosForObjectDetection":
                labels = []
                for i in range(self.model_tester.batch_size):
                    target = {}
                    target["class_labels"] = torch.ones(
                        size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long
                    )
                    target["boxes"] = torch.ones(
                        self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float
                    )
                    labels.append(target)
                inputs_dict["labels"] = labels

        return inputs_dict

    def setUp(self):
        self.model_tester = YolosModelTester(self)
        self.config_tester = ConfigTester(self, config_class=YolosConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_inputs_embeds(self):
        # YOLOS does not use inputs_embeds
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        # in YOLOS, the seq_len is different
        seq_len = self.model_tester.expected_seq_len
        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_len, seq_len],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_len, seq_len],
            )

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            # YOLOS has a different seq_length
            seq_length = self.model_tester.expected_seq_len

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_for_object_detection(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_object_detection(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = YolosModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_vision
class YolosModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_feature_extractor(self):
        return AutoFeatureExtractor.from_pretrained("hustvl/yolos-small") if is_vision_available() else None

    @slow
    def test_inference_object_detection_head(self):
        model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small").to(torch_device)

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(inputs.pixel_values)

        # verify the logits
        expected_shape = torch.Size((1, 100, 92))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice_logits = torch.tensor(
            [[-24.0248, -10.3024, -14.8290], [-42.0392, -16.8200, -27.4334], [-27.2743, -11.8154, -18.7148]],
            device=torch_device,
        )
        expected_slice_boxes = torch.tensor(
            [[0.2559, 0.5455, 0.4706], [0.2989, 0.7279, 0.1875], [0.7732, 0.4017, 0.4462]], device=torch_device
        )
        self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice_logits, atol=1e-4))
        self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
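

# A minimal, hypothetical post-processing sketch (not part of the YOLOS tests themselves): it assumes the
# DETR-style convention used above, where the last class index is the implicit "no object" class
# (92 logits = 91 COCO classes + "no object" for hustvl/yolos-small), and an arbitrary 0.9 threshold.
def keep_confident_detections(logits, pred_boxes, threshold=0.9):
    """Return (scores, class indices, normalized boxes) for detection tokens scoring above `threshold`."""
    probs = logits.softmax(-1)[0, :, :-1]  # softmax over classes, drop the "no object" column
    keep = probs.max(-1).values > threshold  # boolean mask over the detection tokens
    scores, classes = probs[keep].max(-1)  # best class and its score per kept token
    return scores, classes, pred_boxes[0, keep]  # boxes are normalized (center_x, center_y, width, height)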
@@ -18,6 +18,7 @@ src/transformers/models/big_bird/modeling_big_bird.py
src/transformers/models/blenderbot/modeling_blenderbot.py
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
src/transformers/models/convnext/modeling_convnext.py
src/transformers/models/ctrl/modeling_ctrl.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deit/modeling_deit.py
@@ -58,5 +59,5 @@ src/transformers/models/vit_mae/modeling_vit_mae.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/ctrl/modeling_ctrl.py
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/yolos/modeling_yolos.py