mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Add Vision Transformer and ViTFeatureExtractor (#10950)
* Squash all commits into one
* Update ViTFeatureExtractor to use image_utils instead of torchvision
* Remove torchvision and add Pillow
* Small docs improvement
* Address most comments by @sgugger
* Fix tests
* Clean up conversion script
* Pooler first draft
* Fix quality
* Improve conversion script
* Make style and quality
* Make fix-copies
* Minor docs improvements
* Should use fix-copies instead of manual handling
* Revert "Should use fix-copies instead of manual handling"
This reverts commit fd4e591bce
.
* Place ViT in alphabetical order
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
af6732225c
commit
30677dc743
@ -80,8 +80,8 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech,vision]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -110,8 +110,8 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech,vision]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -139,8 +139,8 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -223,8 +223,8 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision]
|
||||
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
|
@ -234,6 +234,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
|
@ -210,22 +210,26 @@ and conversion utilities for the following models:
|
||||
43. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
44. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
44. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
|
||||
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
|
||||
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
|
||||
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
45. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
45. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
46. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
46. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
47. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
47. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
48. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
48. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
49. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
49. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
50. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
|
||||
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||
|
||||
@ -328,6 +332,8 @@ TensorFlow and/or Flax.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
@ -460,6 +466,7 @@ TensorFlow and/or Flax.
|
||||
model_doc/t5
|
||||
model_doc/tapas
|
||||
model_doc/transformerxl
|
||||
model_doc/vit
|
||||
model_doc/wav2vec2
|
||||
model_doc/xlm
|
||||
model_doc/xlmprophetnet
|
||||
|
102
docs/source/model_doc/vit.rst
Normal file
102
docs/source/model_doc/vit.rst
Normal file
@ -0,0 +1,102 @@
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
Vision Transformer (ViT)
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. note::
|
||||
|
||||
This is a recently introduced model so the API hasn't been tested extensively. There may be some bugs or slight
|
||||
breaking changes to fix it in the future. If you see something strange, file a `Github Issue
|
||||
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__.
|
||||
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The Vision Transformer (ViT) model was proposed in `An Image is Worth 16x16 Words: Transformers for Image Recognition
|
||||
at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk
|
||||
Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob
|
||||
Uszkoreit, Neil Houlsby. It's the first paper that successfully trains a Transformer encoder on ImageNet, attaining
|
||||
very good results compared to familiar convolutional architectures.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its
|
||||
applications to computer vision remain limited. In vision, attention is either applied in conjunction with
|
||||
convolutional networks, or used to replace certain components of convolutional networks while keeping their overall
|
||||
structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to
|
||||
sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of
|
||||
data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.),
|
||||
Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring
|
||||
substantially fewer computational resources to train.*
|
||||
|
||||
Tips:
|
||||
|
||||
- To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches,
|
||||
which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image, which can be
|
||||
used for classification. The authors also add absolute position embeddings, and feed the resulting sequence of
|
||||
vectors to a standard Transformer encoder.
|
||||
- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to
|
||||
use a higher resolution than pre-training `(Touvron et al., 2019) <https://arxiv.org/abs/1906.06423>`__, `(Kolesnikov
|
||||
et al., 2020) <https://arxiv.org/abs/1912.11370>`__. The authors report the best results with a resolution of 384x384
|
||||
during fine-tuning.
|
||||
- As the Vision Transformer expects each image to be of the same size (resolution), one can use
|
||||
:class:`~transformers.ViTFeatureExtractor` to resize (or rescale) and normalize images for the model.
|
||||
- Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of
|
||||
each checkpoint. For example, :obj:`google/vit-base-patch16-224` refers to a base-sized architecture with patch
|
||||
resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the `hub
|
||||
<https://huggingface.co/models?search=vit>`__.
|
||||
- The available checkpoints are either (1) pre-trained on `ImageNet-21k <http://www.image-net.org/>`__ (a collection of
|
||||
14 million images and 21k classes) only, or (2) also fine-tuned on `ImageNet
|
||||
<http://www.image-net.org/challenges/LSVRC/2012/>`__ (also referred to as ILSVRC 2012, a collection of 1.3 million
|
||||
images and 1,000 classes).
|
||||
- The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also performed
|
||||
an experiment with a self-supervised pre-training objective, namely masked patched prediction (inspired by masked
|
||||
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
|
||||
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.
|
||||
|
||||
|
||||
The original code (written in JAX) can be found `here <https://github.com/google-research/vision_transformer>`__.
|
||||
|
||||
Note that we converted the weights from Ross Wightman's `timm library
|
||||
<https://github.com/rwightman/pytorch-image-models>`__, who already converted the weights from JAX to PyTorch. Credits
|
||||
go to him!
|
||||
|
||||
|
||||
ViTConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.ViTConfig
|
||||
:members:
|
||||
|
||||
|
||||
ViTFeatureExtractor
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.ViTFeatureExtractor
|
||||
:members: __call__
|
||||
|
||||
|
||||
ViTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.ViTModel
|
||||
:members: forward
|
||||
|
||||
|
||||
ViTForImageClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.ViTForImageClassification
|
||||
:members: forward
|
4
setup.py
4
setup.py
@ -107,6 +107,7 @@ _deps = [
|
||||
"onnxruntime>=1.4.0",
|
||||
"packaging",
|
||||
"parameterized",
|
||||
"Pillow",
|
||||
"protobuf",
|
||||
"psutil",
|
||||
"pydantic",
|
||||
@ -230,6 +231,7 @@ extras["sagemaker"] = deps_list("sagemaker")
|
||||
|
||||
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
|
||||
extras["speech"] = deps_list("soundfile", "torchaudio")
|
||||
extras["vision"] = deps_list("Pillow")
|
||||
|
||||
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
||||
extras["testing"] = (
|
||||
@ -242,7 +244,7 @@ extras["testing"] = (
|
||||
extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
|
||||
extras["quality"] = deps_list("black", "isort", "flake8")
|
||||
|
||||
extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"]
|
||||
extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"] + extras["speech"] + extras["vision"]
|
||||
|
||||
extras["dev"] = (
|
||||
extras["all"]
|
||||
|
@ -213,6 +213,7 @@ _import_structure = {
|
||||
"TransfoXLCorpus",
|
||||
"TransfoXLTokenizer",
|
||||
],
|
||||
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
"models.wav2vec2": [
|
||||
"WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"Wav2Vec2Config",
|
||||
@ -299,7 +300,7 @@ else:
|
||||
name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
# tokenziers-backed objects
|
||||
# tokenizers-backed objects
|
||||
if is_tokenizers_available():
|
||||
# Fast tokenizers
|
||||
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
|
||||
@ -348,6 +349,7 @@ else:
|
||||
# Vision-specific objects
|
||||
if is_vision_available():
|
||||
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
|
||||
_import_structure["models.vit"].append("ViTFeatureExtractor")
|
||||
else:
|
||||
from .utils import dummy_vision_objects
|
||||
|
||||
@ -426,6 +428,7 @@ if is_torch_available():
|
||||
_import_structure["models.auto"].extend(
|
||||
[
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@ -867,6 +870,14 @@ if is_torch_available():
|
||||
"load_tf_weights_in_transfo_xl",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vit"].extend(
|
||||
[
|
||||
"VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"ViTForImageClassification",
|
||||
"ViTModel",
|
||||
"ViTPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.wav2vec2"].extend(
|
||||
[
|
||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -1311,7 +1322,6 @@ else:
|
||||
name for name in dir(dummy_flax_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
|
||||
# Direct imports for type-checking
|
||||
if TYPE_CHECKING:
|
||||
# Configuration
|
||||
@ -1479,6 +1489,7 @@ if TYPE_CHECKING:
|
||||
TransfoXLCorpus,
|
||||
TransfoXLTokenizer,
|
||||
)
|
||||
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
from .models.wav2vec2 import (
|
||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
Wav2Vec2Config,
|
||||
@ -1601,6 +1612,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_vision_available():
|
||||
from .image_utils import ImageFeatureExtractionMixin
|
||||
from .models.vit import ViTFeatureExtractor
|
||||
else:
|
||||
from .utils.dummy_vision_objects import *
|
||||
|
||||
@ -1666,6 +1678,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -2025,6 +2038,12 @@ if TYPE_CHECKING:
|
||||
TransfoXLPreTrainedModel,
|
||||
load_tf_weights_in_transfo_xl,
|
||||
)
|
||||
from .models.vit import (
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
ViTForImageClassification,
|
||||
ViTModel,
|
||||
ViTPreTrainedModel,
|
||||
)
|
||||
from .models.wav2vec2 import (
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2ForCTC,
|
||||
@ -2400,6 +2419,7 @@ if TYPE_CHECKING:
|
||||
# Import the same objects as dummies to get them in the namespace.
|
||||
# They will raise an import error if the user tries to instantiate / use them.
|
||||
from .utils.dummy_flax_objects import *
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
|
@ -24,6 +24,7 @@ deps = {
|
||||
"onnxruntime": "onnxruntime>=1.4.0",
|
||||
"packaging": "packaging",
|
||||
"parameterized": "parameterized",
|
||||
"Pillow": "Pillow",
|
||||
"protobuf": "protobuf",
|
||||
"psutil": "psutil",
|
||||
"pydantic": "pydantic",
|
||||
|
@ -175,10 +175,11 @@ try:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_soundfile_available = False
|
||||
|
||||
_torchaudio_available = importlib.util.find_spec("torchaudio")
|
||||
|
||||
_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
|
||||
try:
|
||||
_torchaudio_version = importlib_metadata.version("torchaudio")
|
||||
logger.debug(f"Successfully imported soundfile version {_torchaudio_version}")
|
||||
logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torchaudio_available = False
|
||||
|
||||
|
@ -120,9 +120,9 @@ class ImageFeatureExtractionMixin:
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
if not isinstance(mean, np.ndarray):
|
||||
mean = np.array(mean)
|
||||
mean = np.array(mean).astype(image.dtype)
|
||||
if not isinstance(std, np.ndarray):
|
||||
std = np.array(std)
|
||||
std = np.array(std).astype(image.dtype)
|
||||
elif is_torch_tensor(image):
|
||||
import torch
|
||||
|
||||
|
@ -67,6 +67,7 @@ from . import (
|
||||
t5,
|
||||
tapas,
|
||||
transfo_xl,
|
||||
vit,
|
||||
wav2vec2,
|
||||
xlm,
|
||||
xlm_roberta,
|
||||
|
@ -29,6 +29,7 @@ _import_structure = {
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_auto"] = [
|
||||
"MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@ -42,6 +43,7 @@ if is_torch_available():
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"AutoModel",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForImageClassification",
|
||||
"AutoModelForMaskedLM",
|
||||
"AutoModelForMultipleChoice",
|
||||
"AutoModelForNextSentencePrediction",
|
||||
@ -90,6 +92,7 @@ if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
from .modeling_auto import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -103,6 +106,7 @@ if TYPE_CHECKING:
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoModelForNextSentencePrediction,
|
||||
|
@ -68,6 +68,7 @@ from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFI
|
||||
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
|
||||
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
||||
from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
|
||||
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
||||
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
|
||||
@ -85,6 +86,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -134,6 +136,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
("gpt_neo", GPTNeoConfig),
|
||||
("big_bird", BigBirdConfig),
|
||||
("speech_to_text", Speech2TextConfig),
|
||||
("vit", ViTConfig),
|
||||
("wav2vec2", Wav2Vec2Config),
|
||||
("m2m_100", M2M100Config),
|
||||
("convbert", ConvBertConfig),
|
||||
@ -189,6 +192,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("gpt_neo", "GPT Neo"),
|
||||
("big_bird", "BigBird"),
|
||||
("speech_to_text", "Speech2Text"),
|
||||
("vit", "ViT"),
|
||||
("wav2vec2", "Wav2Vec2"),
|
||||
("m2m_100", "M2M100"),
|
||||
("convbert", "ConvBERT"),
|
||||
|
@ -237,6 +237,7 @@ from ..tapas.modeling_tapas import (
|
||||
TapasModel,
|
||||
)
|
||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
|
||||
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
|
||||
from ..xlm.modeling_xlm import (
|
||||
XLMForMultipleChoice,
|
||||
@ -313,6 +314,7 @@ from .configuration_auto import (
|
||||
T5Config,
|
||||
TapasConfig,
|
||||
TransfoXLConfig,
|
||||
ViTConfig,
|
||||
Wav2Vec2Config,
|
||||
XLMConfig,
|
||||
XLMProphetNetConfig,
|
||||
@ -331,6 +333,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(GPTNeoConfig, GPTNeoModel),
|
||||
(BigBirdConfig, BigBirdModel),
|
||||
(Speech2TextConfig, Speech2TextModel),
|
||||
(ViTConfig, ViTModel),
|
||||
(Wav2Vec2Config, Wav2Vec2Model),
|
||||
(M2M100Config, M2M100Model),
|
||||
(ConvBertConfig, ConvBertModel),
|
||||
@ -490,6 +493,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Image Classification mapping
|
||||
(ViTConfig, ViTForImageClassification),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
@ -1864,3 +1874,100 @@ class AutoModelForNextSentencePrediction:
|
||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForImageClassification:
|
||||
r"""
|
||||
This is a generic model class that will be instantiated as one of the model classes of the library---with an image
|
||||
classification head---when created with the :meth:`~transformers.AutoModelForImageClassification.from_pretrained`
|
||||
class method or the :meth:`~transformers.AutoModelForImageClassification.from_config` class method.
|
||||
|
||||
This class cannot be instantiated directly using ``__init__()`` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoModelForImageClassification is designed to be instantiated "
|
||||
"using the `AutoModelForImageClassification.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
"`AutoModelForImageClassification.from_config(config)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, use_model_types=False)
|
||||
def from_config(cls, config):
|
||||
r"""
|
||||
Instantiates one of the model classes of the library---with an image classification head---from a
|
||||
configuration.
|
||||
|
||||
Note:
|
||||
Loading a model from its configuration file does **not** load the model weights. It only affects the
|
||||
model's configuration. Use :meth:`~transformers.AutoModelForImageClassification.from_pretrained` to load
|
||||
the model weights.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.PretrainedConfig`):
|
||||
The model class to instantiate is selected based on the configuration class:
|
||||
|
||||
List options
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoConfig, AutoModelForImageClassification
|
||||
>>> # Download configuration from huggingface.co and cache.
|
||||
>>> config = AutoConfig.from_pretrained('google/vit_base_patch16_224')
|
||||
>>> model = AutoModelForImageClassification.from_config(config)
|
||||
"""
|
||||
if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys():
|
||||
return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)](config)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__,
|
||||
cls.__name__,
|
||||
", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||
@add_start_docstrings(
|
||||
"Instantiate one of the model classes of the library---with an image classification head---from a "
|
||||
"pretrained model.",
|
||||
AUTO_MODEL_PRETRAINED_DOCSTRING,
|
||||
)
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoConfig, AutoModelForImageClassification
|
||||
|
||||
>>> # Download model and configuration from huggingface.co and cache.
|
||||
>>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224')
|
||||
|
||||
>>> # Update configuration during loading
|
||||
>>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224', output_attentions=True)
|
||||
>>> model.config.output_attentions
|
||||
True
|
||||
|
||||
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
||||
>>> config = AutoConfig.from_json_file('./tf_model/vit_tf_model_config.json')
|
||||
>>> model = AutoModelForImageClassification.from_pretrained('./tf_model/vit_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys():
|
||||
return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
||||
)
|
||||
raise ValueError(
|
||||
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||
"Model type should be one of {}.".format(
|
||||
config.__class__,
|
||||
cls.__name__,
|
||||
", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()),
|
||||
)
|
||||
)
|
||||
|
70
src/transformers/models/vit/__init__.py
Normal file
70
src/transformers/models/vit/__init__.py
Normal file
@ -0,0 +1,70 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
}
|
||||
|
||||
if is_vision_available():
|
||||
_import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_vit"] = [
|
||||
"VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"ViTForImageClassification",
|
||||
"ViTModel",
|
||||
"ViTPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
|
||||
if is_vision_available():
|
||||
from .feature_extraction_vit import ViTFeatureExtractor
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_vit import (
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
ViTForImageClassification,
|
||||
ViTModel,
|
||||
ViTPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
class _LazyModule(_BaseLazyModule):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
__file__ = globals()["__file__"]
|
||||
__path__ = [os.path.dirname(__file__)]
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
|
116
src/transformers/models/vit/configuration_vit.py
Normal file
116
src/transformers/models/vit/configuration_vit.py
Normal file
@ -0,0 +1,116 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" ViT model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"nielsr/vit-base-patch16-224": "https://huggingface.co/vit-base-patch16-224/resolve/main/config.json",
|
||||
# See all ViT models at https://huggingface.co/models?filter=vit
|
||||
}
|
||||
|
||||
|
||||
class ViTConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.ViTModel`. It is used to
|
||||
instantiate an ViT model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the ViT `google/vit-base-patch16-224
|
||||
<https://huggingface.co/google/vit-base-patch16-224>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
|
||||
The size (resolution) of each image.
|
||||
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (:obj:`int`, `optional`, defaults to :obj:`3`):
|
||||
The number of input channels.
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import ViTModel, ViTConfig
|
||||
|
||||
>>> # Initializing a ViT vit-base-patch16-224 style configuration
|
||||
>>> configuration = ViTConfig()
|
||||
|
||||
>>> # Initializing a model from the vit-base-patch16-224 style configuration
|
||||
>>> model = ViTModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "vit"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
is_encoder_decoder=False,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
228
src/transformers/models/vit/convert_vit_timm_to_pytorch.py
Normal file
228
src/transformers/models/vit/convert_vit_timm_to_pytorch.py
Normal file
@ -0,0 +1,228 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ViT checkpoints from the timm library."""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import requests
|
||||
import timm
|
||||
from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel
|
||||
from transformers.utils import logging
|
||||
from transformers.utils.imagenet_classes import id2label
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, base_model=False):
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
|
||||
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
|
||||
rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
("cls_token", "vit.embeddings.cls_token"),
|
||||
("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
|
||||
("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
|
||||
("pos_embed", "vit.embeddings.position_embeddings"),
|
||||
]
|
||||
)
|
||||
|
||||
if base_model:
|
||||
# layernorm + pooler
|
||||
rename_keys.extend(
|
||||
[
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
("pre_logits.fc.weight", "pooler.dense.weight"),
|
||||
("pre_logits.fc.bias", "pooler.dense.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
# if just the base model, we should remove "vit" from all keys that start with "vit"
|
||||
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("norm.weight", "vit.layernorm.weight"),
|
||||
("norm.bias", "vit.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, base_model=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
if base_model:
|
||||
prefix = ""
|
||||
else:
|
||||
prefix = "vit."
|
||||
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
|
||||
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
|
||||
config.hidden_size : config.hidden_size * 2
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
|
||||
|
||||
|
||||
def remove_classification_head_(state_dict):
|
||||
ignore_keys = ["head.weight", "head.bias"]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ViT structure.
|
||||
"""
|
||||
|
||||
# define default ViT configuration
|
||||
config = ViTConfig()
|
||||
base_model = False
|
||||
# dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
|
||||
if vit_name[-5:] == "in21k":
|
||||
base_model = True
|
||||
config.patch_size = int(vit_name[-12:-10])
|
||||
config.image_size = int(vit_name[-9:-6])
|
||||
else:
|
||||
config.num_labels = 1000
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.patch_size = int(vit_name[-6:-4])
|
||||
config.image_size = int(vit_name[-3:])
|
||||
# size of the architecture
|
||||
if vit_name[4:].startswith("small"):
|
||||
config.hidden_size = 768
|
||||
config.intermediate_size = 2304
|
||||
config.num_hidden_layers = 8
|
||||
config.num_attention_heads = 8
|
||||
if vit_name[4:].startswith("base"):
|
||||
pass
|
||||
elif vit_name[4:].startswith("large"):
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
elif vit_name[4:].startswith("huge"):
|
||||
config.hidden_size = 1280
|
||||
config.intermediate_size = 5120
|
||||
config.num_hidden_layers = 32
|
||||
config.num_attention_heads = 16
|
||||
|
||||
# load original model from timm
|
||||
timm_model = timm.create_model(vit_name, pretrained=True)
|
||||
timm_model.eval()
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = timm_model.state_dict()
|
||||
if base_model:
|
||||
remove_classification_head_(state_dict)
|
||||
rename_keys = create_rename_keys(config, base_model)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, base_model)
|
||||
|
||||
# load HuggingFace model
|
||||
if vit_name[-5:] == "in21k":
|
||||
model = ViTModel(config).eval()
|
||||
else:
|
||||
model = ViTForImageClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image, prepared by ViTFeatureExtractor
|
||||
feature_extractor = ViTFeatureExtractor(size=config.image_size)
|
||||
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
outputs = model(pixel_values)
|
||||
|
||||
if base_model:
|
||||
timm_pooled_output = timm_model.forward_features(pixel_values)
|
||||
assert timm_pooled_output.shape == outputs.pooler_output.shape
|
||||
assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
|
||||
else:
|
||||
timm_logits = timm_model(pixel_values)
|
||||
assert timm_logits.shape == outputs.logits.shape
|
||||
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--vit_name",
|
||||
default="vit_base_patch16_224",
|
||||
type=str,
|
||||
help="Name of the ViT timm model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
|
130
src/transformers/models/vit/feature_extraction_vit.py
Normal file
130
src/transformers/models/vit/feature_extraction_vit.py
Normal file
@ -0,0 +1,130 @@
|
||||
# coding=utf-8
|
||||
# Copyright Google AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Feature extractor class for ViT."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
from ...file_utils import TensorType
|
||||
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
r"""
|
||||
Constructs a ViT feature extractor.
|
||||
|
||||
This feature extractor inherits from :class:`~transformers.FeatureExtractionMixin` which contains most of the main
|
||||
methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
image_mean (:obj:`int`, defaults to :obj:`[0.5, 0.5, 0.5]`):
|
||||
The sequence of means for each channel, to be used when normalizing images.
|
||||
image_std (:obj:`int`, defaults to :obj:`[0.5, 0.5, 0.5]`):
|
||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
||||
do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to normalize the input with mean and standard deviation.
|
||||
do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to resize the input to a certain :obj:`size`.
|
||||
size (:obj:`int`, `optional`, defaults to 224):
|
||||
Resize the input to the given size. Only has an effect if :obj:`do_resize` is set to :obj:`True`.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(self, image_mean=None, image_std=None, do_normalize=True, do_resize=True, size=224, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.image_mean = [0.5, 0.5, 0.5]
|
||||
self.image_std = [0.5, 0.5, 0.5]
|
||||
self.do_normalize = do_normalize
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: Union[
|
||||
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
|
||||
],
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several image(s).
|
||||
|
||||
.. warning::
|
||||
|
||||
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
|
||||
PIL images.
|
||||
|
||||
Args:
|
||||
images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
|
||||
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.s
|
||||
* :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model.
|
||||
"""
|
||||
# Input type checking for clearer error
|
||||
valid_images = False
|
||||
|
||||
# Check that images has a valid type
|
||||
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
|
||||
valid_images = True
|
||||
elif isinstance(images, (list, tuple)):
|
||||
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
|
||||
valid_images = True
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError(
|
||||
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
|
||||
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(images, (list, tuple))
|
||||
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
|
||||
)
|
||||
|
||||
if not is_batched:
|
||||
images = [images]
|
||||
|
||||
# transformations (resizing + normalization)
|
||||
if self.do_resize and self.size is not None:
|
||||
images = [self.resize(image=image, size=self.size) for image in images]
|
||||
if self.do_normalize:
|
||||
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
|
||||
|
||||
# return as BatchFeature
|
||||
data = {"pixel_values": images}
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
return encoded_inputs
|
629
src/transformers/models/vit/modeling_vit.py
Normal file
629
src/transformers/models/vit/modeling_vit.py
Normal file
@ -0,0 +1,629 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 Google AI, Ross Weightman, The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch ViT model. """
|
||||
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import logging
|
||||
from .configuration_vit import ViTConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "ViTConfig"
|
||||
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"nielsr/vit-base-patch16-224",
|
||||
# See all ViT models at https://huggingface.co/models?filter=vit
|
||||
]
|
||||
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
||||
# From PyTorch internals
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
||||
|
||||
# Based on timm implementation, which can be found here:
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
|
||||
|
||||
class ViTEmbeddings(nn.Module):
|
||||
"""
|
||||
Construct the CLS token, position and patch embeddings.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
||||
self.patch_embeddings = PatchEmbeddings(
|
||||
image_size=config.image_size,
|
||||
patch_size=config.patch_size,
|
||||
num_channels=config.num_channels,
|
||||
embed_dim=config.hidden_size,
|
||||
)
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
batch_size = pixel_values.shape[0]
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
# Based on timm implementation, which can be found here:
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
class PatchEmbeddings(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
||||
super().__init__()
|
||||
image_size = to_2tuple(image_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class ViTSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
||||
f"heads {config.num_attention_heads}."
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, head_mask=None, output_attentions=False):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ViTSelfOutput(nn.Module):
|
||||
"""
|
||||
The residual connection is defined in VitLayer instead of here (as is the case with other models), due to the
|
||||
layernorm applied before each block.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.attention = ViTSelfAttention(config)
|
||||
self.output = ViTSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
heads, index = find_pruneable_heads_and_indices(
|
||||
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
||||
)
|
||||
|
||||
# Prune linear layers
|
||||
self.attention.query = prune_linear_layer(self.attention.query, index)
|
||||
self.attention.key = prune_linear_layer(self.attention.key, index)
|
||||
self.attention.value = prune_linear_layer(self.attention.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
|
||||
# Update hyper params and store pruned heads
|
||||
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
||||
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
def forward(self, hidden_states, head_mask=None, output_attentions=False):
|
||||
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
class ViTIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + input_tensor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTLayer(nn.Module):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ViTAttention(config)
|
||||
self.intermediate = ViTIntermediate(config)
|
||||
self.output = ViTOutput(config)
|
||||
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, head_mask=None, output_attentions=False):
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
# first residual connection
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
# in ViT, layernorm is also applied after self-attention
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
|
||||
# TODO feedforward chunking not working for now
|
||||
# layer_output = apply_chunking_to_forward(
|
||||
# self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layer_output
|
||||
# )
|
||||
|
||||
layer_output = self.intermediate(layer_output)
|
||||
|
||||
# second residual connection is done here
|
||||
layer_output = self.output(layer_output, hidden_states)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class ViTEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
class ViTPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = ViTConfig
|
||||
base_model_prefix = "vit"
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
VIT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
|
||||
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||
weights.
|
||||
"""
|
||||
|
||||
VIT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
:class:`~transformers.ViTFeatureExtractor`. See :meth:`transformers.ViTFeatureExtractor.__call__` for
|
||||
details.
|
||||
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
VIT_START_DOCSTRING,
|
||||
)
|
||||
class ViTModel(ViTPreTrainedModel):
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = ViTEmbeddings(config)
|
||||
self.encoder = ViTEncoder(config)
|
||||
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.pooler = ViTPooler(config) if add_pooling_layer else None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import ViTFeatureExtractor, ViTModel
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||
>>> model = ViTModel.from_pretrained('google/vit-base-patch16-224')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(pixel_values)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
sequence_output = self.layernorm(sequence_output)
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class ViTPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||
the [CLS] token) e.g. for ImageNet.
|
||||
""",
|
||||
VIT_START_DOCSTRING,
|
||||
)
|
||||
class ViTForImageClassification(ViTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.vit = ViTModel(config, add_pooling_layer=False)
|
||||
|
||||
# Classifier head
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
head_mask=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
||||
>>> model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = logits.argmax(-1).item()
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.vit(
|
||||
pixel_values,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
logits = self.classifier(sequence_output[:, 0, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.num_labels == 1:
|
||||
# We are doing regression
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), labels.view(-1))
|
||||
else:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
@ -302,6 +302,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_MASKED_LM_MAPPING = None
|
||||
|
||||
|
||||
@ -2512,6 +2515,32 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
|
||||
requires_pytorch(load_tf_weights_in_transfo_xl)
|
||||
|
||||
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class ViTForImageClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class ViTModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class ViTPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -5,3 +5,8 @@ from ..file_utils import requires_vision
|
||||
class ImageFeatureExtractionMixin:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_vision(self)
|
||||
|
||||
|
||||
class ViTFeatureExtractor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_vision(self)
|
||||
|
1003
src/transformers/utils/imagenet_classes.py
Normal file
1003
src/transformers/utils/imagenet_classes.py
Normal file
File diff suppressed because it is too large
Load Diff
221
tests/test_feature_extraction_vit.py
Normal file
221
tests/test_feature_extraction_vit.py
Normal file
@ -0,0 +1,221 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
|
||||
from .test_feature_extraction_common import FeatureExtractionSavingTestMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ViTFeatureExtractor
|
||||
|
||||
|
||||
class ViTFeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
do_normalize=True,
|
||||
do_resize=True,
|
||||
size=18,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_normalize = do_normalize
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_normalize": self.do_normalize,
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
}
|
||||
|
||||
def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
"""
|
||||
|
||||
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
|
||||
|
||||
if equal_resolution:
|
||||
image_inputs = []
|
||||
for i in range(self.batch_size):
|
||||
image_inputs.append(
|
||||
np.random.randint(
|
||||
255, size=(self.num_channels, self.max_resolution, self.max_resolution), dtype=np.uint8
|
||||
)
|
||||
)
|
||||
else:
|
||||
image_inputs = []
|
||||
for i in range(self.batch_size):
|
||||
width, height = np.random.choice(np.arange(self.min_resolution, self.max_resolution), 2)
|
||||
image_inputs.append(np.random.randint(255, size=(self.num_channels, width, height), dtype=np.uint8))
|
||||
|
||||
if not numpify and not torchify:
|
||||
# PIL expects the channel dimension as last dimension
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
if torchify:
|
||||
image_inputs = [torch.from_numpy(x) for x in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = ViTFeatureExtractor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = ViTFeatureExtractionTester(self)
|
||||
|
||||
@property
|
||||
def feat_extract_dict(self):
|
||||
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
def test_feat_extract_properties(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
self.assertTrue(hasattr(feature_extractor, "image_mean"))
|
||||
self.assertTrue(hasattr(feature_extractor, "image_std"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
@ -264,7 +264,9 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
||||
|
||||
# During the conversion rescale and channel first will be applied.
|
||||
expected = array.transpose(2, 0, 1).astype(np.float32) / 255.0
|
||||
expected = (expected - np.array(mean)[:, None, None]) / np.array(std)[:, None, None]
|
||||
np_mean = np.array(mean).astype(np.float32)[:, None, None]
|
||||
np_std = np.array(std).astype(np.float32)[:, None, None]
|
||||
expected = (expected - np_mean) / np_std
|
||||
self.assertTrue(np.array_equal(normalized_image, expected))
|
||||
|
||||
def test_normalize_array(self):
|
||||
|
@ -34,6 +34,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@ -99,6 +100,7 @@ class ModelTesterMixin:
|
||||
elif model_class in [
|
||||
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
|
||||
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
|
||||
*MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
|
365
tests/test_modeling_vit.py
Normal file
365
tests/test_modeling_vit.py
Normal file
@ -0,0 +1,365 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch ViT model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import ViTConfig, ViTForImageClassification, ViTModel
|
||||
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ViTFeatureExtractor
|
||||
|
||||
|
||||
class ViTModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = ViTConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = ViTModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(self.image_size)
|
||||
patch_size = to_2tuple(self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = ViTForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
pixel_values,
|
||||
labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class ViTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as ViT does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
ViTModel,
|
||||
ViTForImageClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ViTModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ViTConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
config = self.config_tester.config_class(**self.config_tester.inputs_dict)
|
||||
# we omit vocab_size since ViT does not use this
|
||||
self.config_tester.parent.assertTrue(hasattr(config, "hidden_size"))
|
||||
self.config_tester.parent.assertTrue(hasattr(config, "num_attention_heads"))
|
||||
self.config_tester.parent.assertTrue(hasattr(config, "num_hidden_layers"))
|
||||
|
||||
self.config_tester.create_and_test_config_to_json_string()
|
||||
self.config_tester.create_and_test_config_to_json_file()
|
||||
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||
self.config_tester.create_and_test_config_with_num_labels()
|
||||
self.config_tester.check_config_can_be_init_without_params()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# ViT does not use inputs_embeds
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 1
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# ViT has a different seq_length
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_length = num_patches + 1
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in VIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = ViTModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/cats.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_vision
|
||||
class ViTModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(torch_device)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
# currently failing
|
||||
# see https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/2
|
||||
outputs = model(inputs["pixel_values"])
|
||||
# outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
|
Loading…
Reference in New Issue
Block a user