Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add OWL-ViT model for zero-shot object detection (#17938)
* add owlvit model skeleton
* add class and box predictor heads
* convert modified flax clip to pytorch
* fix box and class predictors
* add OwlViTImageTextEmbedder
* convert class and box head checkpoints
* convert image text embedder checkpoints
* add object detection head
* fix bugs
* update conversion script
* update conversion script
* fix q,v,k,out weight conversion
* add owlvit object detection output
* fix bug in image embedder
* fix bugs in text embedder
* fix positional embeddings
* fix bug in inference mode vision pooling
* update docs, init tokenizer and processor files
* support batch processing
* add OwlViTProcessor
* remove merge conflicts
* readd owlvit imports
* fix bug in OwlViTProcessor imports
* fix bugs in processor
* update docs
* fix bugs in processor
* update owlvit docs
* add OwlViTFeatureExtractor
* style changes, add postprocess method to feature extractor
* add feature extractor and processor tests
* add object detection tests
* update conversion script
* update config paths
* update config paths
* fix configuration paths and bugs
* fix bugs in OwlViT tests
* add import checks to processor
* fix docs and minor issues
* fix docs and minor issues
* fix bugs and issues
* fix bugs and issues
* fix bugs and issues
* fix bugs and issues
* update docs and examples
* fix bugs and issues
* update conversion script, fix positional embeddings
* process 2D input ids, update tests
* fix style and quality issues
* update docs
* update docs and imports
* update OWL-ViT index.md
* fix bug in OwlViT feature ext tests
* fix code examples, return_dict by default
* return_dict by default
* minor fixes, add tests to processor
* small fixes
* add output_attentions arg to main model
* fix bugs
* remove output_hidden_states arg from main model
* update self.config variables
* add option to return last_hidden_states
* fix bug in config variables
* fix copied from statements
* fix small issues and bugs
* fix bugs
* fix bugs, support greyscale images
* run fixup
* update repo name
* merge OwlViTImageTextEmbedder with obj detection head
* fix merge conflict
* fix merge conflict
* make fixup
* fix bugs
* add additional processor test
Parent: 99eb9b523f
Commit: 12d66b4701
@ -332,6 +332,7 @@ Current number of checkpoints: **
1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira.
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
@ -288,6 +288,7 @@ installing Flax, PyTorch, and TensorFlow with conda from their installation pages
1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira.
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
@ -312,6 +312,7 @@ conda install -c huggingface transformers
1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira.
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
@ -324,6 +324,7 @@ conda install -c huggingface transformers
1. **[NLLB](https://huggingface.co/docs/transformers/main/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nyströmformer](https://huggingface.co/docs/transformers/model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/main/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
1. **[Perceiver IO](https://huggingface.co/docs/transformers/model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira.
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
@ -326,6 +326,8 @@
      title: Nyströmformer
    - local: model_doc/opt
      title: OPT
    - local: model_doc/owlvit
      title: OWL-ViT
    - local: model_doc/pegasus
      title: Pegasus
    - local: model_doc/perceiver
@ -130,6 +130,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
1. **[NLLB](model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[Nyströmformer](model_doc/nystromformer)** (from the University of Wisconsin - Madison) released with the paper [Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention](https://arxiv.org/abs/2102.03902) by Yunyang Xiong, Zhanpeng Zeng, Rudrasis Chakraborty, Mingxing Tan, Glenn Fung, Yin Li, Vikas Singh.
1. **[OPT](master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[Pegasus](model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
1. **[Perceiver IO](model_doc/perceiver)** (from Deepmind) released with the paper [Perceiver IO: A General Architecture for Structured Inputs & Outputs](https://arxiv.org/abs/2107.14795) by Andrew Jaegle, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, Daniel Zoran, Andrew Brock, Evan Shelhamer, Olivier Hénaff, Matthew M. Botvinick, Andrew Zisserman, Oriol Vinyals, João Carreira.
1. **[PhoBERT](model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
@ -263,6 +264,7 @@ Flax), PyTorch, and/or TensorFlow.
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | ✅ |
| OPT | ❌ | ❌ | ✅ | ✅ | ✅ |
| OWL-ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Pegasus | ✅ | ✅ | ✅ | ✅ | ✅ |
| Perceiver | ✅ | ❌ | ✅ | ❌ | ❌ |
| PLBart | ✅ | ❌ | ✅ | ❌ | ❌ |
docs/source/en/model_doc/owlvit.mdx (new file, 101 lines)
@ -0,0 +1,101 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# OWL-ViT

## Overview

The OWL-ViT model (short for Vision Transformer for Open-World Localization) was proposed in [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby. OWL-ViT is an open-vocabulary object detection network trained on a variety of (image, text) pairs. It can be used to query an image with one or multiple text queries to search for and detect target objects described in text.

The abstract from the paper is the following:

*Combining simple architectures with large-scale pre-training has led to massive improvements in image classification. For object detection, pre-training and scaling approaches are less well established, especially in the long-tailed and open-vocabulary setting, where training data is relatively scarce. In this paper, we propose a strong recipe for transferring image-text models to open-vocabulary object detection. We use a standard Vision Transformer architecture with minimal modifications, contrastive image-text pre-training, and end-to-end detection fine-tuning. Our analysis of the scaling properties of this setup shows that increasing image-level pre-training and model size yield consistent improvements on the downstream detection task. We provide the adaptation strategies and regularizations needed to attain very strong performance on zero-shot text-conditioned and one-shot image-conditioned object detection. Code and models are available on GitHub.*

## Usage

OWL-ViT is a zero-shot text-conditioned object detection model. OWL-ViT uses [CLIP](clip) as its multi-modal backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a lightweight classification and box head to each transformer output token. Open-vocabulary classification is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image can be used to perform zero-shot text-conditioned object detection.

[`OwlViTFeatureExtractor`] can be used to resize (or rescale) and normalize images for the model, and [`CLIPTokenizer`] is used to encode the text. [`OwlViTProcessor`] wraps [`OwlViTFeatureExtractor`] and [`CLIPTokenizer`] into a single instance to both encode the text and prepare the images. The following example shows how to perform object detection using [`OwlViTProcessor`] and [`OwlViTForObjectDetection`].

```python
>>> import requests
>>> from PIL import Image
>>> import torch

>>> from transformers import OwlViTProcessor, OwlViTForObjectDetection

>>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
>>> model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")

>>> outputs = model(**inputs)
>>> logits = outputs["logits"]  # Prediction logits of shape [batch_size, num_patches, num_max_text_queries]
>>> boxes = outputs["pred_boxes"]  # Object box boundaries of shape [batch_size, num_patches, 4]

>>> batch_size = boxes.shape[0]
>>> for i in range(batch_size):  # Loop over sets of images and text queries
...     # Box predictions and per-box query logits for the i-th image
...     image_boxes = outputs["pred_boxes"][i]
...     image_logits = torch.max(outputs["logits"][i], dim=-1)
...     scores = torch.sigmoid(image_logits.values)  # Confidence of the best-matching query for each box
...     labels = image_logits.indices  # Index of the best-matching text query for each box
```
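
To turn these raw outputs into final detections, you typically keep only the boxes whose best query score clears a threshold. A minimal sketch continuing the example above (the `0.1` threshold is an arbitrary illustration rather than a model default, and each `label` indexes the list of text queries passed for that image; boxes are printed in the model's normalized coordinate format):

```python
>>> text_queries = ["a photo of a cat", "a photo of a dog"]
>>> score_threshold = 0.1  # illustrative value, tune per use case

>>> for i in range(batch_size):
...     query_logits = torch.max(outputs["logits"][i], dim=-1)
...     scores = torch.sigmoid(query_logits.values)
...     labels = query_logits.indices
...     keep = scores >= score_threshold
...     for box, score, label in zip(outputs["pred_boxes"][i][keep], scores[keep], labels[keep]):
...         print(f"Detected '{text_queries[label.item()]}' with confidence {round(score.item(), 3)} at {box.tolist()}")
```

Note that [`OwlViTFeatureExtractor`] also ships a post-processing method (added in this PR, see its API reference below) for converting raw model outputs into final detections.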

This model was contributed by [adirik](https://huggingface.co/adirik). The original code can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).

## OwlViTConfig

[[autodoc]] OwlViTConfig
    - from_text_vision_configs
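
The snippet below is a minimal, illustrative sketch of composing a full config from separate text and vision configs with the `from_text_vision_configs` class method documented above; as implemented in this PR the method expects plain dictionaries, so the sub-configs are serialized with `to_dict()` first:

```python
>>> from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig

>>> # Default text and vision sub-configurations, serialized to dictionaries
>>> text_config = OwlViTTextConfig().to_dict()
>>> vision_config = OwlViTVisionConfig(image_size=768, patch_size=32).to_dict()

>>> # Combine them into a single OWL-ViT configuration
>>> config = OwlViTConfig.from_text_vision_configs(text_config=text_config, vision_config=vision_config)
```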

## OwlViTTextConfig

[[autodoc]] OwlViTTextConfig

## OwlViTVisionConfig

[[autodoc]] OwlViTVisionConfig

## OwlViTFeatureExtractor

[[autodoc]] OwlViTFeatureExtractor
    - __call__
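
If only the image side is needed, the feature extractor can be used on its own to produce `pixel_values`; a minimal sketch, reusing the `image` and checkpoint from the usage example above:

```python
>>> from transformers import OwlViTFeatureExtractor

>>> feature_extractor = OwlViTFeatureExtractor.from_pretrained("google/owlvit-base-patch32")
>>> image_inputs = feature_extractor(images=image, return_tensors="pt")
>>> pixel_values = image_inputs["pixel_values"]  # resized, rescaled and normalized image tensor
```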

## OwlViTProcessor

[[autodoc]] OwlViTProcessor

## OwlViTModel

[[autodoc]] OwlViTModel
    - forward
    - get_text_features
    - get_image_features
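
The two feature methods listed above return text and image embeddings without running the detection heads. A minimal sketch, assuming they follow the usual CLIP-style signatures (`input_ids`/`attention_mask` for text, `pixel_values` for images) and reusing the inputs from the usage example:

```python
>>> from transformers import OwlViTModel, OwlViTProcessor

>>> model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")
>>> processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

>>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
>>> text_embeds = model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
>>> image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
```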

## OwlViTTextModel

[[autodoc]] OwlViTTextModel
    - forward

## OwlViTVisionModel

[[autodoc]] OwlViTVisionModel
    - forward

## OwlViTForObjectDetection

[[autodoc]] OwlViTForObjectDetection
    - forward
@ -273,6 +273,13 @@ _import_structure = {
    ],
    "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"],
    "models.opt": ["OPTConfig"],
    "models.owlvit": [
        "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "OwlViTConfig",
        "OwlViTProcessor",
        "OwlViTTextConfig",
        "OwlViTVisionConfig",
    ],
    "models.pegasus": ["PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "PegasusConfig", "PegasusTokenizer"],
    "models.perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverTokenizer"],
    "models.phobert": ["PhobertTokenizer"],
@ -641,6 +648,7 @@ else:
    _import_structure["models.levit"].append("LevitFeatureExtractor")
    _import_structure["models.maskformer"].append("MaskFormerFeatureExtractor")
    _import_structure["models.mobilevit"].append("MobileViTFeatureExtractor")
    _import_structure["models.owlvit"].append("OwlViTFeatureExtractor")
    _import_structure["models.perceiver"].append("PerceiverFeatureExtractor")
    _import_structure["models.poolformer"].append("PoolFormerFeatureExtractor")
    _import_structure["models.segformer"].append("SegformerFeatureExtractor")
@ -1507,6 +1515,16 @@ else:
            "OPTForSequenceClassification",
        ]
    )
    _import_structure["models.owlvit"].extend(
        [
            "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
            "OwlViTModel",
            "OwlViTPreTrainedModel",
            "OwlViTTextModel",
            "OwlViTVisionModel",
            "OwlViTForObjectDetection",
        ]
    )
    _import_structure["models.pegasus"].extend(
        ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
    )
@ -3012,6 +3030,13 @@ if TYPE_CHECKING:
    from .models.nystromformer import NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, NystromformerConfig
    from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer
    from .models.opt import OPTConfig
    from .models.owlvit import (
        OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OwlViTConfig,
        OwlViTProcessor,
        OwlViTTextConfig,
        OwlViTVisionConfig,
    )
    from .models.pegasus import PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, PegasusConfig, PegasusTokenizer
    from .models.perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverTokenizer
    from .models.phobert import PhobertTokenizer
@ -3328,6 +3353,7 @@ if TYPE_CHECKING:
        from .models.levit import LevitFeatureExtractor
        from .models.maskformer import MaskFormerFeatureExtractor
        from .models.mobilevit import MobileViTFeatureExtractor
        from .models.owlvit import OwlViTFeatureExtractor
        from .models.perceiver import PerceiverFeatureExtractor
        from .models.poolformer import PoolFormerFeatureExtractor
        from .models.segformer import SegformerFeatureExtractor
@ -4044,6 +4070,14 @@ if TYPE_CHECKING:
            OPTModel,
            OPTPreTrainedModel,
        )
        from .models.owlvit import (
            OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            OwlViTForObjectDetection,
            OwlViTModel,
            OwlViTPreTrainedModel,
            OwlViTTextModel,
            OwlViTVisionModel,
        )
        from .models.pegasus import (
            PegasusForCausalLM,
            PegasusForConditionalGeneration,
@ -101,6 +101,7 @@ from . import (
    nystromformer,
    openai,
    opt,
    owlvit,
    pegasus,
    perceiver,
    phobert,
@ -98,6 +98,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("nystromformer", "NystromformerConfig"),
        ("openai-gpt", "OpenAIGPTConfig"),
        ("opt", "OPTConfig"),
        ("owlvit", "OwlViTConfig"),
        ("pegasus", "PegasusConfig"),
        ("perceiver", "PerceiverConfig"),
        ("plbart", "PLBartConfig"),
@ -216,6 +217,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
        ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("openai-gpt", "OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("opt", "OPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("owlvit", "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("perceiver", "PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("plbart", "PLBART_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -346,6 +348,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("nystromformer", "Nyströmformer"),
        ("openai-gpt", "OpenAI GPT"),
        ("opt", "OPT"),
        ("owlvit", "OWL-ViT"),
        ("pegasus", "Pegasus"),
        ("perceiver", "Perceiver"),
        ("phobert", "PhoBERT"),
@ -58,6 +58,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
        ("maskformer", "MaskFormerFeatureExtractor"),
        ("mctct", "MCTCTFeatureExtractor"),
        ("mobilevit", "MobileViTFeatureExtractor"),
        ("owlvit", "OwlViTFeatureExtractor"),
        ("perceiver", "PerceiverFeatureExtractor"),
        ("poolformer", "PoolFormerFeatureExtractor"),
        ("regnet", "ConvNextFeatureExtractor"),
@ -98,6 +98,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("nystromformer", "NystromformerModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlvit", "OwlViTModel"),
        ("pegasus", "PegasusModel"),
        ("perceiver", "PerceiverModel"),
        ("plbart", "PLBartModel"),
@ -43,6 +43,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("layoutlmv2", "LayoutLMv2Processor"),
        ("layoutlmv3", "LayoutLMv3Processor"),
        ("layoutxlm", "LayoutXLMProcessor"),
        ("owlvit", "OwlViTProcessor"),
        ("sew", "Wav2Vec2Processor"),
        ("sew-d", "Wav2Vec2Processor"),
        ("speech_to_text", "Speech2TextProcessor"),
@ -193,6 +193,7 @@ else:
        ),
        ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
        ("opt", ("GPT2Tokenizer", None)),
        ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
        (
            "pegasus",
            (
@ -199,6 +199,7 @@ class CLIPVisionConfig(PretrainedConfig):
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
@ -216,6 +217,7 @@ class CLIPVisionConfig(PretrainedConfig):
        self.dropout = dropout
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
src/transformers/models/owlvit/__init__.py (new file, 100 lines)
@ -0,0 +1,100 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
    is_vision_available,
)


_import_structure = {
    "configuration_owlvit": [
        "OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "OwlViTConfig",
        "OwlViTTextConfig",
        "OwlViTVisionConfig",
    ],
    "processing_owlvit": ["OwlViTProcessor"],
}


try:
    if not is_vision_available() or not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["feature_extraction_owlvit"] = ["OwlViTFeatureExtractor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_owlvit"] = [
        "OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "OwlViTModel",
        "OwlViTPreTrainedModel",
        "OwlViTTextModel",
        "OwlViTVisionModel",
        "OwlViTForObjectDetection",
    ]

if TYPE_CHECKING:
    from .configuration_owlvit import (
        OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        OwlViTConfig,
        OwlViTTextConfig,
        OwlViTVisionConfig,
    )
    from .processing_owlvit import OwlViTProcessor

    try:
        if not is_vision_available() or not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .feature_extraction_owlvit import OwlViTFeatureExtractor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_owlvit import (
            OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            OwlViTForObjectDetection,
            OwlViTModel,
            OwlViTPreTrainedModel,
            OwlViTTextModel,
            OwlViTVisionModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
src/transformers/models/owlvit/configuration_owlvit.py (new file, 336 lines)
@ -0,0 +1,336 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" OWL-ViT model configuration"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
OWLVIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"google/owlvit-base-patch32": "https://huggingface.co/google/owlvit-base-patch32/resolve/main/config.json",
|
||||
"google/owlvit-base-patch16": "https://huggingface.co/google/owlvit-base-patch16/resolve/main/config.json",
|
||||
"google/owlvit-large-patch14": "https://huggingface.co/google/owlvit-large-patch14/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class OwlViTTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`OwlViTTextModel`]. It is used to instantiate an
|
||||
OwlViT text encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the OwlViT
|
||||
[google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 49408):
|
||||
Vocabulary size of the OWL-ViT text model. Defines the number of different tokens that can be represented
|
||||
by the `input_ids` passed when calling [`OwlViTTextModel`].
|
||||
hidden_size (`int`, *optional*, defaults to 512):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 2048):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 16):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
initializer_factor (`float`, *optional*, defaults to 1):
|
||||
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
||||
testing).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import OwlViTTextConfig, OwlViTTextModel
|
||||
|
||||
>>> # Initializing an OwlViTTextConfig with google/owlvit-base-patch32 style configuration
|
||||
>>> configuration = OwlViTTextConfig()
|
||||
|
||||
>>> # Initializing an OwlViTTextModel (with random weights) from the google/owlvit-base-patch32 style configuration
|
||||
>>> model = OwlViTTextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "owlvit_text_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=49408,
|
||||
hidden_size=512,
|
||||
intermediate_size=2048,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=8,
|
||||
max_position_embeddings=16,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=0.00001,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=0,
|
||||
bos_token_id=49406,
|
||||
eos_token_id=49407,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.initializer_factor = initializer_factor
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the text config dict if we are loading from OwlViTConfig
|
||||
if config_dict.get("model_type") == "owlvit":
|
||||
config_dict = config_dict["text_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class OwlViTVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`OwlViTVisionModel`]. It is used to instantiate
|
||||
an OWL-ViT image encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the OWL-ViT
|
||||
[google/owlvit-base-patch32](https://huggingface.co/google/owlvit-base-patch32) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
image_size (`int`, *optional*, defaults to 768):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 32):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
initializer_factor (`float`, *optional*, defaults to 1):
|
||||
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
||||
testing).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import OwlViTVisionConfig, OwlViTVisionModel
|
||||
|
||||
>>> # Initializing an OwlViTVisionConfig with google/owlvit-base-patch32 style configuration
|
||||
>>> configuration = OwlViTVisionConfig()
|
||||
|
||||
>>> # Initializing an OwlViTVisionModel (with random weights) from the google/owlvit-base-patch32 style configuration
|
||||
>>> model = OwlViTVisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "owlvit_vision_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
image_size=768,
|
||||
patch_size=32,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=0.00001,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.initializer_factor = initializer_factor
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the vision config dict if we are loading from OwlViTConfig
|
||||
if config_dict.get("model_type") == "owlvit":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class OwlViTConfig(PretrainedConfig):
|
||||
r"""
|
||||
[`OwlViTConfig`] is the configuration class to store the configuration of an [`OwlViTModel`]. It is used to
|
||||
instantiate an OWL-ViT model according to the specified arguments, defining the text model and vision model
|
||||
configs.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`OwlViTTextConfig`].
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`OwlViTVisionConfig`].
|
||||
projection_dim (`int`, *optional*, defaults to 512):
|
||||
Dimensionality of text and vision projection layers.
|
||||
logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
|
||||
The initial value of the *logit_scale* parameter. Default is used as per the original OWL-ViT
|
||||
implementation.
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
"""
|
||||
|
||||
model_type = "owlvit"
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
projection_dim=512,
|
||||
logit_scale_init_value=2.6592,
|
||||
return_dict=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(text_config=text_config, vision_config=vision_config, **kwargs)
|
||||
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
logger.info("text_config_dict is None. Initializing the OwlViTTextConfig with default values.")
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
logger.info("vision_config_dict is None. initializing the OwlViTVisionConfig with default values.")
|
||||
|
||||
self.text_config = OwlViTTextConfig(**text_config)
|
||||
self.vision_config = OwlViTVisionConfig(**vision_config)
|
||||
|
||||
self.projection_dim = projection_dim
|
||||
self.logit_scale_init_value = logit_scale_init_value
|
||||
self.return_dict = return_dict
|
||||
self.initializer_factor = 1.0
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_text_vision_configs(cls, text_config: Dict, vision_config: Dict, **kwargs):
|
||||
r"""
|
||||
Instantiate an [`OwlViTConfig`] (or a derived class) from OWL-ViT text model configuration and OWL-ViT vision
|
||||
model configuration.
|
||||
|
||||
Returns:
|
||||
[`OwlViTConfig`]: An instance of a configuration object
|
||||
"""
|
||||
config_dict = {}
|
||||
config_dict["text_config"] = text_config
|
||||
config_dict["vision_config"] = vision_config
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
@ -0,0 +1,407 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert OWL-ViT checkpoints from the original repository. URL:
|
||||
https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit"""
|
||||
|
||||
import argparse
|
||||
import collections.abc
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from clip.model import CLIP
|
||||
from flax.training import checkpoints
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
CLIPTokenizer,
|
||||
OwlViTConfig,
|
||||
OwlViTFeatureExtractor,
|
||||
OwlViTForObjectDetection,
|
||||
OwlViTModel,
|
||||
OwlViTProcessor,
|
||||
)
|
||||
|
||||
|
||||
CONFIGS = {
|
||||
"vit_b32": dict(
|
||||
embed_dim=512,
|
||||
image_resolution=768,
|
||||
context_length=16,
|
||||
vocab_size=49408,
|
||||
vision_layers=12,
|
||||
vision_width=768,
|
||||
vision_patch_size=32,
|
||||
transformer_width=512,
|
||||
transformer_heads=8,
|
||||
transformer_layers=12,
|
||||
),
|
||||
"vit_b16": dict(
|
||||
embed_dim=512,
|
||||
image_resolution=768,
|
||||
context_length=16,
|
||||
vocab_size=49408,
|
||||
vision_layers=12,
|
||||
vision_width=768,
|
||||
vision_patch_size=16,
|
||||
transformer_width=512,
|
||||
transformer_heads=8,
|
||||
transformer_layers=12,
|
||||
),
|
||||
"vit_l14": dict(
|
||||
embed_dim=768,
|
||||
image_resolution=840,
|
||||
context_length=16,
|
||||
vocab_size=49408,
|
||||
vision_layers=24,
|
||||
vision_width=1024,
|
||||
vision_patch_size=14,
|
||||
transformer_width=768,
|
||||
transformer_heads=12,
|
||||
transformer_layers=12,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def flatten_nested_dict(params, parent_key="", sep="/"):
|
||||
items = []
|
||||
|
||||
for k, v in params.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
|
||||
if isinstance(v, collections.abc.MutableMapping):  # collections.MutableMapping was removed in Python 3.10
|
||||
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
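# Illustrative example (not part of the original script):
#   flatten_nested_dict({"backbone": {"clip": {"w": 1}}, "b": 2})
#   returns {"backbone/clip/w": 1, "b": 2}, i.e. the nested Flax parameter tree
#   is flattened into "/"-joined keys before they are renamed to PyTorch keys.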
|
||||
|
||||
|
||||
def to_f32(params):
|
||||
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params)
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_attn_layer.out_proj.weight
|
||||
out_proj_bias = pt_attn_layer.out_proj.bias
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_mlp):
|
||||
copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
|
||||
copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
|
||||
copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_layer.mlp)
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_layers):
|
||||
for hf_layer, pt_layer in zip(hf_layers, pt_layers):
|
||||
copy_layer(hf_layer, pt_layer)
|
||||
|
||||
|
||||
def copy_encoder(hf_encoder, pt_model):
|
||||
# copy embeds
|
||||
hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
|
||||
hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
|
||||
|
||||
# copy layer norm
|
||||
copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
|
||||
|
||||
# copy hidden layers
|
||||
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_model.text_projection.data.T
|
||||
|
||||
# copy text encoder
|
||||
copy_encoder(hf_model.text_model, pt_model)
|
||||
|
||||
|
||||
def copy_vision_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layernorm, pt_model.visual.ln_pre)
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
|
||||
|
||||
# copy embeds
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
|
||||
hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
|
||||
|
||||
|
||||
def copy_class_merge_token(hf_model, flax_params):
|
||||
flax_class_token_params = flatten_nested_dict(flax_params["backbone"]["merged_class_token"])
|
||||
|
||||
weight = torch.from_numpy(flax_class_token_params["scale"])
|
||||
bias = torch.from_numpy(flax_class_token_params["bias"])
|
||||
hf_model.layer_norm.weight = nn.Parameter(weight)
|
||||
hf_model.layer_norm.bias = nn.Parameter(bias)
|
||||
|
||||
|
||||
def copy_class_box_heads(hf_model, flax_params):
|
||||
pt_params = hf_model.state_dict()
|
||||
new_params = {}
|
||||
|
||||
# Rename class prediction head flax params to pytorch HF
|
||||
flax_class_params = flatten_nested_dict(flax_params["class_head"])
|
||||
|
||||
for flax_key, v in flax_class_params.items():
|
||||
torch_key = flax_key.replace("/", ".")
|
||||
torch_key = torch_key.replace(".kernel", ".weight")
|
||||
torch_key = torch_key.replace("Dense_0", "dense0")
|
||||
torch_key = "class_head." + torch_key
|
||||
|
||||
if "weight" in torch_key and v.ndim == 2:
|
||||
v = v.T
|
||||
|
||||
new_params[torch_key] = nn.Parameter(torch.from_numpy(v))
|
||||
|
||||
# Rename box prediction box flax params to pytorch HF
|
||||
flax_box_params = flatten_nested_dict(flax_params["obj_box_head"])
|
||||
|
||||
for flax_key, v in flax_box_params.items():
|
||||
torch_key = flax_key.replace("/", ".")
|
||||
torch_key = torch_key.replace(".kernel", ".weight")
|
||||
torch_key = torch_key.replace("_", "").lower()
|
||||
torch_key = "box_head." + torch_key
|
||||
|
||||
if "weight" in torch_key and v.ndim == 2:
|
||||
v = v.T
|
||||
|
||||
new_params[torch_key] = nn.Parameter(torch.from_numpy(v))
|
||||
|
||||
# Copy flax params to PyTorch params
|
||||
for name, param in new_params.items():
|
||||
if name in pt_params.keys():
|
||||
pt_params[name].copy_(param)
|
||||
|
||||
|
||||
def copy_flax_attn_params(hf_backbone, flax_attn_params):
|
||||
for k, v in flax_attn_params.items():
|
||||
if k.startswith("transformer"):
|
||||
torch_key = k.replace("transformer.resblocks", "text_model.encoder.layers")
|
||||
else:
|
||||
torch_key = k.replace("visual.transformer.resblocks", "vision_model.encoder.layers")
|
||||
|
||||
torch_key = torch_key.replace("attn", "self_attn")
|
||||
torch_key = torch_key.replace("key", "k_proj")
|
||||
torch_key = torch_key.replace("value", "v_proj")
|
||||
torch_key = torch_key.replace("query", "q_proj")
|
||||
torch_key = torch_key.replace("out", "out_proj")
|
||||
|
||||
if "bias" in torch_key and v.ndim == 2:
|
||||
shape = v.shape[0] * v.shape[1]
|
||||
v = v.reshape(shape)
|
||||
|
||||
if "weight" in torch_key and "out" in torch_key:
|
||||
shape = (v.shape[0] * v.shape[1], v.shape[2])
|
||||
v = v.reshape(shape).T
|
||||
|
||||
if "weight" in torch_key and "out" not in torch_key:
|
||||
shape = (v.shape[0], v.shape[1] * v.shape[2])
|
||||
v = v.reshape(shape).T
|
||||
|
||||
# Copy flax CLIP attn params to HF PyTorch params
|
||||
v = torch.from_numpy(v)
|
||||
hf_backbone.state_dict()[torch_key].copy_(v)
|
||||
|
||||
|
||||
def _convert_attn_layers(params):
|
||||
new_params = {}
|
||||
processed_attn_layers = []
|
||||
|
||||
for k, v in params.items():
|
||||
if "attn." in k:
|
||||
base = k[: k.rindex("attn.") + 5]
|
||||
if base in processed_attn_layers:
|
||||
continue
|
||||
|
||||
processed_attn_layers.append(base)
|
||||
dim = params[base + "out.weight"].shape[-1]
|
||||
new_params[base + "out_proj.weight"] = params[base + "out.weight"].reshape(dim, dim).T
|
||||
new_params[base + "out_proj.bias"] = params[base + "out.bias"]
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def convert_clip_backbone(flax_params, torch_config):
|
||||
torch_model = CLIP(**torch_config)
|
||||
torch_model.eval()
|
||||
torch_clip_params = torch_model.state_dict()
|
||||
|
||||
flax_clip_params = flatten_nested_dict(flax_params["backbone"]["clip"])
|
||||
new_torch_params = {}
|
||||
|
||||
for flax_key, v in flax_clip_params.items():
|
||||
torch_key = flax_key.replace("/", ".")
|
||||
torch_key = torch_key.replace("text.token_embedding.embedding", "token_embedding.kernel")
|
||||
|
||||
if (
|
||||
torch_key.startswith("text.transformer")
|
||||
or torch_key.startswith("text.text_projection")
|
||||
or torch_key.startswith("text.ln_final")
|
||||
or torch_key.startswith("text.positional_embedding")
|
||||
):
|
||||
torch_key = torch_key[5:]
|
||||
|
||||
torch_key = torch_key.replace("text_projection.kernel", "text_projection")
|
||||
torch_key = torch_key.replace("visual.proj.kernel", "visual.proj")
|
||||
torch_key = torch_key.replace(".scale", ".weight")
|
||||
torch_key = torch_key.replace(".kernel", ".weight")
|
||||
|
||||
if "conv" in torch_key or "downsample.0.weight" in torch_key:
|
||||
v = v.transpose(3, 2, 0, 1)
|
||||
|
||||
elif "weight" in torch_key and v.ndim == 2 and "embedding" not in torch_key:
|
||||
# Fully connected layers are transposed, embeddings are not
|
||||
v = v.T
|
||||
|
||||
new_torch_params[torch_key] = v
|
||||
|
||||
attn_params = _convert_attn_layers(new_torch_params)
|
||||
new_torch_params.update(attn_params)
|
||||
attn_params = {}
|
||||
|
||||
# Copy flax CLIP backbone params to PyTorch params
|
||||
for name, param in new_torch_params.items():
|
||||
if name in torch_clip_params.keys():
|
||||
|
||||
new_param = torch.from_numpy(new_torch_params[name])
|
||||
torch_clip_params[name].copy_(new_param)
|
||||
else:
|
||||
attn_params[name] = param
|
||||
|
||||
return torch_clip_params, torch_model, attn_params
|
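# Flax keys that have no direct match in the PyTorch CLIP state dict (the per-layer attention
# projection weights) are collected in `attn_params` and copied separately by `copy_flax_attn_params`.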
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_owlvit_checkpoint(pt_backbone, flax_params, attn_params, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
repo = Repository(pytorch_dump_folder_path, clone_from=f"google/{pytorch_dump_folder_path}")
|
||||
repo.git_pull()
|
||||
|
||||
if config_path is not None:
|
||||
config = OwlViTConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = OwlViTConfig()
|
||||
|
||||
hf_backbone = OwlViTModel(config).eval()
|
||||
hf_model = OwlViTForObjectDetection(config).eval()
|
||||
|
||||
copy_text_model_and_projection(hf_backbone, pt_backbone)
|
||||
copy_vision_model_and_projection(hf_backbone, pt_backbone)
|
||||
hf_backbone.logit_scale = pt_backbone.logit_scale
|
||||
copy_flax_attn_params(hf_backbone, attn_params)
|
||||
|
||||
hf_model.owlvit = hf_backbone
|
||||
copy_class_merge_token(hf_model, flax_params)
|
||||
copy_class_box_heads(hf_model, flax_params)
|
||||
|
||||
# Save HF model
|
||||
hf_model.save_pretrained(repo.local_dir)
|
||||
|
||||
# Initialize feature extractor
|
||||
feature_extractor = OwlViTFeatureExtractor(
|
||||
size=config.vision_config.image_size, crop_size=config.vision_config.image_size
|
||||
)
|
||||
# Initialize tokenizer
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", pad_token="!", model_max_length=16)
|
||||
|
||||
# Initialize processor
|
||||
processor = OwlViTProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
feature_extractor.save_pretrained(repo.local_dir)
|
||||
processor.save_pretrained(repo.local_dir)
|
||||
|
||||
repo.git_add()
|
||||
repo.git_commit("Upload model and processor")
|
||||
repo.git_push()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--owlvit_version",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="OWL-ViT model name [clip_b16, clip_b32, clip_l14].",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--owlvit_checkpoint", default=None, type=str, required=True, help="Path to flax model checkpoint."
|
||||
)
|
||||
parser.add_argument("--hf_config", default=None, type=str, required=True, help="Path to HF model config.")
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default="hf_model", type=str, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize PyTorch CLIP model
|
||||
model_name = args.owlvit_version
|
||||
if model_name == "clip_b16":
|
||||
torch_config = CONFIGS["vit_b16"]
|
||||
elif model_name == "clip_b32":
|
||||
torch_config = CONFIGS["vit_b32"]
|
||||
elif model_name == "clip_l14":
|
||||
torch_config = CONFIGS["vit_l14"]
|
||||
|
||||
# Load from checkpoint and convert params to float-32
|
||||
variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"]
|
||||
flax_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables)
|
||||
del variables
|
||||
|
||||
# Convert CLIP backbone
|
||||
pt_backbone_params, clip_pt, attn_params = convert_clip_backbone(flax_params, torch_config)
|
||||
|
||||
convert_owlvit_checkpoint(clip_pt, flax_params, attn_params, args.pytorch_dump_folder_path, args.hf_config)
|
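# Illustrative invocation of this conversion script (script name and paths are placeholders, adjust
# them to your local flax checkpoint and config):
#
#   python convert_owlvit_checkpoint.py \
#       --owlvit_version clip_b32 \
#       --owlvit_checkpoint /path/to/flax/checkpoint \
#       --hf_config /path/to/owlvit_config.json \
#       --pytorch_dump_folder_path owlvit_hf_model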
src/transformers/models/owlvit/feature_extraction_owlvit.py (new file, 210 lines)
@@ -0,0 +1,210 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Feature extractor class for OwlViT."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
|
||||
from ...utils import TensorType, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def center_to_corners_format(x):
|
||||
"""
|
||||
Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
|
||||
(left, top, right, bottom).
|
||||
"""
|
||||
x_center, y_center, width, height = x.unbind(-1)
|
||||
boxes = [(x_center - 0.5 * width), (y_center - 0.5 * height), (x_center + 0.5 * width), (y_center + 0.5 * height)]
|
||||
return torch.stack(boxes, dim=-1)
|
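# Quick sanity check of the conversion above (illustrative only):
# >>> center_to_corners_format(torch.tensor([[0.5, 0.5, 0.2, 0.4]]))
# tensor([[0.4000, 0.3000, 0.6000, 0.7000]])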
||||
|
||||
|
||||
class OwlViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
r"""
|
||||
Constructs an OWL-ViT feature extractor.
|
||||
|
||||
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
|
||||
should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the shorter edge of the input to a certain `size`.
|
||||
size (`int`, *optional*, defaults to 768):
|
||||
Resize the shorter edge of the input to the given size. Only has an effect if `do_resize` is set to `True`.
|
||||
resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
|
||||
An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
|
||||
`PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
|
||||
if `do_resize` is set to `True`.
|
||||
do_center_crop (`bool`, *optional*, defaults to `True`):
|
||||
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
|
||||
image is padded with 0's and then center cropped.
|
||||
crop_size (`int`, *optional*, defaults to 768):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
|
||||
image_mean (`List[int]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
The sequence of means for each channel, to be used when normalizing images.
|
||||
image_std (`List[int]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize=True,
|
||||
size=768,
|
||||
resample=Image.BICUBIC,
|
||||
crop_size=768,
|
||||
do_center_crop=True,
|
||||
do_normalize=True,
|
||||
image_mean=None,
|
||||
image_std=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.crop_size = crop_size
|
||||
self.do_resize = do_resize
|
||||
self.do_center_crop = do_center_crop
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
||||
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
def post_process(self, outputs, target_sizes):
|
||||
"""
|
||||
Converts the output of [`OwlViTForObjectDetection`] into the format expected by the COCO api.
|
||||
|
||||
Args:
|
||||
outputs ([`OwlViTObjectDetectionOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
|
||||
image size (before any data augmentation). For visualization, this should be the image size after data
|
||||
augmentation, but before padding.
|
||||
Returns:
|
||||
`List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
|
||||
in the batch as predicted by the model.
|
||||
"""
|
||||
out_logits, out_bbox = outputs.logits, outputs.pred_boxes
|
||||
|
||||
if len(out_logits) != len(target_sizes):
|
||||
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
|
||||
if target_sizes.shape[1] != 2:
|
||||
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||
|
||||
prob = nn.functional.softmax(out_logits, -1)
|
||||
scores, labels = prob[..., :-1].max(-1)
|
||||
|
||||
# Convert to [x0, y0, x1, y1] format
|
||||
boxes = center_to_corners_format(out_bbox)
|
||||
|
||||
# Convert from relative [0, 1] to absolute [0, height] coordinates
|
||||
img_h, img_w = target_sizes.unbind(1)
|
||||
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
|
||||
boxes = boxes * scale_fct[:, None, :]
|
||||
|
||||
results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
|
||||
|
||||
return results
|
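# Hedged usage sketch for `post_process` (the checkpoint name, image file and score threshold below
# are illustrative assumptions, not part of this file):
#
# import torch
# from PIL import Image
# from transformers import OwlViTProcessor, OwlViTForObjectDetection
#
# processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
# model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
# image = Image.open("cats.jpg")
# inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
# with torch.no_grad():
#     outputs = model(**inputs)
# target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
# results = processor.feature_extractor.post_process(outputs=outputs, target_sizes=target_sizes)
# kept_boxes = results[0]["boxes"][results[0]["scores"] > 0.1]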
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Union[
|
||||
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
|
||||
],
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several image(s).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
PIL images.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W) or (H, W, C),
|
||||
where C is the number of channels, and H and W are the image height and width.
|
||||
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model.
|
||||
"""
|
||||
# Input type checking for clearer error
|
||||
valid_images = False
|
||||
|
||||
# Check that images has a valid type
|
||||
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
|
||||
valid_images = True
|
||||
elif isinstance(images, (list, tuple)):
|
||||
if isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
|
||||
valid_images = True
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError(
|
||||
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
|
||||
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(images, (list, tuple))
|
||||
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
|
||||
)
|
||||
|
||||
if not is_batched:
|
||||
images = [images]
|
||||
|
||||
# transformations (resizing + center cropping + normalization)
|
||||
if self.do_resize and self.size is not None and self.resample is not None:
|
||||
images = [
|
||||
self.resize(image=image, size=self.size, resample=self.resample, default_to_square=False)
|
||||
for image in images
|
||||
]
|
||||
if self.do_center_crop and self.crop_size is not None:
|
||||
images = [self.center_crop(image, self.crop_size) for image in images]
|
||||
if self.do_normalize:
|
||||
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
|
||||
|
||||
# return as BatchFeature
|
||||
data = {"pixel_values": images}
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
return encoded_inputs
|
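# Minimal usage sketch for the feature extractor (the image file name is an illustrative placeholder):
#
# from PIL import Image
# from transformers import OwlViTFeatureExtractor
#
# feature_extractor = OwlViTFeatureExtractor(size=768, crop_size=768)
# pixel_values = feature_extractor(Image.open("example.jpg"), return_tensors="pt").pixel_values
# print(pixel_values.shape)  # expected: torch.Size([1, 3, 768, 768])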
src/transformers/models/owlvit/modeling_owlvit.py (new file, 1386 lines; diff suppressed because it is too large)

src/transformers/models/owlvit/processing_owlvit.py (new file, 154 lines)
@@ -0,0 +1,154 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Image/Text processor class for OWL-ViT
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
|
||||
|
||||
class OwlViTProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an OWL-ViT processor which wraps [`OwlViTFeatureExtractor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`]
|
||||
into a single processor that inherits both the feature extractor and tokenizer functionalities. See the
|
||||
[`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor ([`OwlViTFeatureExtractor`]):
|
||||
The feature extractor is a required input.
|
||||
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
feature_extractor_class = "OwlViTFeatureExtractor"
|
||||
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
|
||||
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
|
||||
def __call__(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs):
|
||||
"""
|
||||
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
|
||||
`kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the
docstring of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
||||
`List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the
number of channels, and H and W are the image height and width.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify at least one text or image. Both cannot be none.")
|
||||
|
||||
if text is not None:
|
||||
if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
|
||||
encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
|
||||
|
||||
elif isinstance(text, List) and isinstance(text[0], List):
|
||||
encodings = []
|
||||
|
||||
# Maximum number of queries across batch
|
||||
max_num_queries = max([len(t) for t in text])
|
||||
|
||||
# Pad all batch samples to max number of text queries
|
||||
for t in text:
|
||||
if len(t) != max_num_queries:
|
||||
t = t + [" "] * (max_num_queries - len(t))
|
||||
|
||||
encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
|
||||
encodings.append(encoding)
|
||||
else:
|
||||
raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
|
||||
|
||||
if return_tensors == "np":
|
||||
input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
|
||||
attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
|
||||
|
||||
elif return_tensors == "jax" and is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
|
||||
attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
|
||||
|
||||
elif return_tensors == "pt" and is_torch_available():
|
||||
import torch
|
||||
|
||||
input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
|
||||
attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
|
||||
|
||||
elif return_tensors == "tf" and is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
|
||||
attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
|
||||
|
||||
else:
|
||||
raise ValueError("Target return tensor type could not be returned")
|
||||
|
||||
encoding = BatchEncoding()
|
||||
encoding["input_ids"] = input_ids
|
||||
encoding["attention_mask"] = attention_mask
|
||||
|
||||
if images is not None:
|
||||
image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
if text is not None and images is not None:
|
||||
encoding["pixel_values"] = image_features.pixel_values
|
||||
return encoding
|
||||
elif text is not None:
|
||||
return encoding
|
||||
else:
|
||||
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
|
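# Hedged sketch of the query padding behaviour implemented above (feature_extractor, tokenizer and the
# two images are assumed to exist; the tokenizer is configured with model_max_length=16 as in the
# conversion script):
#
# processor = OwlViTProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# inputs = processor(
#     text=[["a photo of a cat", "a photo of a dog"], ["a photo of a bird"]],
#     images=[image1, image2],
#     return_tensors="pt",
# )
# # The second sample is padded with " " to two queries, so input_ids has shape
# # (batch_size * max_num_queries, model_max_length) = (4, 16).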
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
@@ -3459,6 +3459,44 @@ class OPTPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class OwlViTForObjectDetection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OwlViTModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OwlViTPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OwlViTTextModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class OwlViTVisionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PegasusForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@@ -122,6 +122,13 @@ class MobileViTFeatureExtractor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class OwlViTFeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class PerceiverFeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
tests/models/owlvit/__init__.py (new empty file)

tests/models/owlvit/test_feature_extraction_owlvit.py (new file, 201 lines)
@@ -0,0 +1,201 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import OwlViTFeatureExtractor
|
||||
|
||||
|
||||
class OwlViTFeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=20,
|
||||
do_center_crop=True,
|
||||
crop_size=18,
|
||||
do_normalize=True,
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_center_crop": self.do_center_crop,
|
||||
"crop_size": self.crop_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = OwlViTFeatureExtractor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = OwlViTFeatureExtractionTester(self)
|
||||
|
||||
@property
|
||||
def feat_extract_dict(self):
|
||||
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
def test_feat_extract_properties(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
|
||||
self.assertTrue(hasattr(feature_extractor, "center_crop"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PIL images
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.crop_size,
|
||||
self.feature_extract_tester.crop_size,
|
||||
),
|
||||
)
|
tests/models/owlvit/test_modeling_owlvit.py (new file, 815 lines)
@@ -0,0 +1,815 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch OwlViT model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import requests
|
||||
from transformers import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
random_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import OwlViTForObjectDetection, OwlViTModel, OwlViTTextModel, OwlViTVisionModel
|
||||
from transformers.models.owlvit.modeling_owlvit import OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import OwlViTProcessor
|
||||
|
||||
|
||||
class OwlViTVisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
image_size=32,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
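# e.g. with the defaults above: (32 // 2) ** 2 = 256 patches, so seq_length = 257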
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return OwlViTVisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = OwlViTVisionModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
pixel_values = pixel_values.to(torch.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class OwlViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as OWLVIT does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (OwlViTVisionModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = OwlViTVisionModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=OwlViTVisionConfig, has_text_modality=False, hidden_size=37
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="OWLVIT does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="OWL-ViT does not support training yet")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OWL-ViT does not support training yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OwlViTVisionModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = OwlViTVisionModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class OwlViTTextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
num_queries=4,
|
||||
seq_length=16,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=16,
|
||||
initializer_range=0.02,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_queries = num_queries
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size * self.num_queries, self.seq_length], self.vocab_size)
|
||||
input_mask = None
|
||||
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size * self.num_queries, self.seq_length])
|
||||
|
||||
if input_mask is not None:
|
||||
num_text, seq_length = input_mask.shape
|
||||
|
||||
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(num_text,))
|
||||
for idx, start_index in enumerate(rnd_start_indices):
|
||||
input_mask[idx, :start_index] = 1
|
||||
input_mask[idx, start_index:] = 0
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask
|
||||
|
||||
def get_config(self):
|
||||
return OwlViTTextConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask):
|
||||
model = OwlViTTextModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(input_ids=input_ids, attention_mask=input_mask)
|
||||
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size * self.num_queries, self.seq_length, self.hidden_size)
|
||||
)
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.num_queries, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, input_mask = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class OwlViTTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (OwlViTTextModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = OwlViTTextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=OwlViTTextConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="OWL-ViT does not support training yet")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OWL-ViT does not support training yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OWLVIT does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OwlViTTextModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = OwlViTTextModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class OwlViTModelTester:
|
||||
def __init__(self, parent, is_training=True):
|
||||
self.parent = parent
|
||||
self.text_model_tester = OwlViTTextModelTester(parent)
|
||||
self.vision_model_tester = OwlViTVisionModelTester(parent)
|
||||
self.is_training = is_training
|
||||
self.text_config = self.text_model_tester.get_config().to_dict()
|
||||
self.vision_config = self.vision_model_tester.get_config().to_dict()
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
config = self.get_config()
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = OwlViTModel(config).to(torch_device).eval()
|
||||
|
||||
with torch.no_grad():
|
||||
result = model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
image_logits_size = (
|
||||
self.vision_model_tester.batch_size,
|
||||
self.text_model_tester.batch_size * self.text_model_tester.num_queries,
|
||||
)
|
||||
text_logits_size = (
|
||||
self.text_model_tester.batch_size * self.text_model_tester.num_queries,
|
||||
self.vision_model_tester.batch_size,
|
||||
)
|
||||
self.parent.assertEqual(result.logits_per_image.shape, image_logits_size)
|
||||
self.parent.assertEqual(result.logits_per_text.shape, text_logits_size)
|
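# The text inputs are flattened to batch_size * num_queries queries, so logits_per_image compares every
# image against every query in the batch, and logits_per_text has the transposed shape.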
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"return_loss": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (OwlViTModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = OwlViTModelTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OwlViTModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
# override as the `logit_scale` parameter initialization is different for OWLVIT
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
# check if `logit_scale` is initialized as per the original implementation
|
||||
if name == "logit_scale":
|
||||
self.assertAlmostEqual(
|
||||
param.data.item(),
|
||||
np.log(1 / 0.07),
|
||||
delta=1e-3,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
else:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
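# For reference, the expected `logit_scale` initialization checked above is np.log(1 / 0.07) ≈ 2.6593.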
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
configs_no_init.return_dict = False
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
try:
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
pixel_values = inputs_dict["pixel_values"] # OWLVIT needs pixel_values
|
||||
traced_model = torch.jit.trace(model, (input_ids, pixel_values))
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
torch.jit.save(traced_model, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load(pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't load module.")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loaded_model.to(torch_device)
|
||||
loaded_model.eval()
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
loaded_model_state_dict = loaded_model.state_dict()
|
||||
|
||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||
|
||||
models_equal = True
|
||||
for layer_name, p1 in model_state_dict.items():
|
||||
p2 = loaded_model_state_dict[layer_name]
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_load_vision_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save OwlViTConfig and check if we can load OwlViTVisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
vision_config = OwlViTVisionConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||
|
||||
# Save OwlViTConfig and check if we can load OwlViTTextConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
text_config = OwlViTTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = OwlViTModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class OwlViTForObjectDetectionTester:
|
||||
def __init__(self, parent, is_training=True):
|
||||
self.parent = parent
|
||||
self.text_model_tester = OwlViTTextModelTester(parent)
|
||||
self.vision_model_tester = OwlViTVisionModelTester(parent)
|
||||
self.is_training = is_training
|
||||
self.text_config = self.text_model_tester.get_config().to_dict()
|
||||
self.vision_config = self.vision_model_tester.get_config().to_dict()
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
config = self.get_config()
|
||||
return config, pixel_values, input_ids, attention_mask
|
||||
|
||||
def get_config(self):
|
||||
return OwlViTConfig.from_text_vision_configs(self.text_config, self.vision_config, projection_dim=64)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, input_ids, attention_mask):
|
||||
model = OwlViTForObjectDetection(config).to(torch_device).eval()
|
||||
with torch.no_grad():
|
||||
result = model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pred_boxes_size = (
|
||||
self.vision_model_tester.batch_size,
|
||||
(self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
|
||||
4,
|
||||
)
|
||||
pred_logits_size = (
|
||||
self.vision_model_tester.batch_size,
|
||||
(self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
|
||||
4,
|
||||
)
|
||||
pred_class_embeds_size = (
|
||||
self.vision_model_tester.batch_size,
|
||||
(self.vision_model_tester.image_size // self.vision_model_tester.patch_size) ** 2,
|
||||
self.text_model_tester.hidden_size,
|
||||
)
|
||||
self.parent.assertEqual(result.pred_boxes.shape, pred_boxes_size)
|
||||
self.parent.assertEqual(result.logits.shape, pred_logits_size)
|
||||
self.parent.assertEqual(result.class_embeds.shape, pred_class_embeds_size)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, input_ids, attention_mask = config_and_inputs
|
||||
inputs_dict = {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (OwlViTForObjectDetection,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = OwlViTForObjectDetectionTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OwlViTModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test_initialization is tested in individual model tests")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test_forward_signature is tested in individual model tests")
|
||||
def test_forward_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test_save_load_fast_init_from_base is tested in individual model tests")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OWL-ViT does not support training yet")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="OWL-ViT does not support training yet")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # OWLVIT needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                                f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OwlViTForObjectDetection.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
class OwlViTModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTModel.from_pretrained(model_name).to(torch_device)
        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=[["a photo of a cat", "a photo of a dog"]],
            images=image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size(
                (
                    inputs.pixel_values.shape[0],
                    inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
                )
            ),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size(
                (
                    inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
                    inputs.pixel_values.shape[0],
                )
            ),
        )

        expected_logits = torch.tensor([[1.0115, 0.9982]], device=torch_device)

        self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))

    @slow
    def test_inference_object_detection(self):
        model_name = "google/owlvit-base-patch32"
        model = OwlViTForObjectDetection.from_pretrained(model_name).to(torch_device)

        processor = OwlViTProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=[["a photo of a cat", "a photo of a dog"]],
            images=image,
            max_length=16,
            padding="max_length",
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
        self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
        expected_slice_boxes = torch.tensor(
            [[0.0143, 0.0236, 0.0285], [0.0649, 0.0247, 0.0437], [0.0601, 0.0446, 0.0699]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
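Not part of the diff: a minimal usage sketch mirroring the integration test above. The checkpoint name and processor arguments are taken from the test; the sigmoid scoring and the 0.1 confidence threshold are assumptions for illustration, not the library's official post-processing.

import requests
import torch
from PIL import Image

from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, max_length=16, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# logits: (batch, num_boxes, num_queries); pred_boxes: (batch, num_boxes, 4) in normalized coordinates
scores, labels = outputs.logits[0].sigmoid().max(dim=-1)
keep = scores > 0.1  # assumed confidence threshold, tune as needed
for box, score, label in zip(outputs.pred_boxes[0][keep], scores[keep], labels[keep]):
    print(texts[0][int(label)], round(float(score), 3), box.tolist())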

tests/models/owlvit/test_processor_owlvit.py (new file, 241 lines)
@@ -0,0 +1,241 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest

import numpy as np
import pytest

from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_vision
from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available


if is_vision_available():
    from PIL import Image

    from transformers import OwlViTFeatureExtractor, OwlViTProcessor


@require_vision
class OwlViTProcessorTest(unittest.TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()

        # fmt: off
        vocab = ["", "l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]
        # fmt: on
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        feature_extractor_map = {
            "do_resize": True,
            "size": 20,
            "do_center_crop": True,
            "crop_size": 18,
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
        }
        self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
            json.dump(feature_extractor_map, fp)

    def get_tokenizer(self, **kwargs):
        return CLIPTokenizer.from_pretrained(self.tmpdirname, pad_token="!", **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        return CLIPTokenizerFast.from_pretrained(self.tmpdirname, pad_token="!", **kwargs)

    def get_feature_extractor(self, **kwargs):
        return OwlViTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
        or a list of PyTorch tensors if one specifies torchify=True.
        """
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        feature_extractor = self.get_feature_extractor()

        processor_slow = OwlViTProcessor(tokenizer=tokenizer_slow, feature_extractor=feature_extractor)
        processor_slow.save_pretrained(self.tmpdirname)
        processor_slow = OwlViTProcessor.from_pretrained(self.tmpdirname, use_fast=False)

        processor_fast = OwlViTProcessor(tokenizer=tokenizer_fast, feature_extractor=feature_extractor)
        processor_fast.save_pretrained(self.tmpdirname)
        processor_fast = OwlViTProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast)

        self.assertEqual(processor_slow.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertEqual(processor_fast.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor_slow.feature_extractor, OwlViTFeatureExtractor)
        self.assertIsInstance(processor_fast.feature_extractor, OwlViTFeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        processor = OwlViTProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False)

        processor = OwlViTProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, OwlViTFeatureExtractor)

    def test_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = feature_extractor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"

        encoded_processor = processor(text=input_str, return_tensors="np")

        encoded_tok = tokenizer(input_str, return_tensors="np")

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key][0].tolist(), encoded_processor[key][0].tolist())

    def test_processor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_processor_with_text_list(self):
        model_name = "google/owlvit-base-patch32"
        processor = OwlViTProcessor.from_pretrained(model_name)

        input_text = ["cat", "nasa badge"]
        inputs = processor(text=input_text)

        seq_length = 16
        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
        self.assertEqual(inputs["input_ids"].shape, (2, seq_length))

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_processor_with_nested_text_list(self):
        model_name = "google/owlvit-base-patch32"
        processor = OwlViTProcessor.from_pretrained(model_name)

        input_texts = [["cat", "nasa badge"], ["person"]]
        inputs = processor(text=input_texts)

        seq_length = 16
        batch_size = len(input_texts)
        num_max_text_queries = max([len(texts) for texts in input_texts])

        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
        self.assertEqual(inputs["input_ids"].shape, (batch_size * num_max_text_queries, seq_length))

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_processor_case(self):
        model_name = "google/owlvit-base-patch32"
        processor = OwlViTProcessor.from_pretrained(model_name)

        input_texts = ["cat", "nasa badge"]
        inputs = processor(text=input_texts)

        seq_length = 16
        input_ids = inputs["input_ids"]
        predicted_ids = [
            [49406, 2368, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [49406, 6841, 11301, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]

        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
        self.assertEqual(inputs["input_ids"].shape, (2, seq_length))
        self.assertListEqual(list(input_ids[0]), predicted_ids[0])
        self.assertListEqual(list(input_ids[1]), predicted_ids[1])

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)
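Not part of the diff: a small sketch of the processor behaviour the nested-text-list tests above assert — per-image query lists are padded to the longest list and flattened to (batch_size * num_max_text_queries, seq_length). The checkpoint name matches the one used in the tests; the shapes in the comments restate the assertions rather than new guarantees.

from transformers import OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

# One query list per image; the shorter list is padded up to the longest one (two queries here).
inputs = processor(text=[["cat", "nasa badge"], ["person"]])

print(inputs["input_ids"].shape)       # per the test above: (2 images * 2 queries, 16)
print(inputs["attention_mask"].shape)  # same leading dimension as input_ids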

@@ -41,6 +41,7 @@ _re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")

CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "CLIPConfig",
    "OwlViTConfig",
    "GroupViTConfig",
    "DecisionTransformerConfig",
    "EncoderDecoderConfig",

@@ -166,6 +166,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "LukeForEntityPairClassification",
    "LukeForEntitySpanClassification",
    "OpenAIGPTDoubleHeadsModel",
    "OwlViTTextModel",
    "OwlViTVisionModel",
    "OwlViTForObjectDetection",
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",