mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
VisualBERT (#10534)
* Init VisualBERT * Add cookie-cutter, Config, and Embeddings * Add preliminary Model * Add Bert analogous classes * Add basic code for NLVR, VQA, Flickr * Update Init * Fix VisualBert Downstream Models * Rename classifier to cls * Comment position_ids buffer * Remove sentence image predictor output * Update output dicts * Remove unnecessary files * Fix Auto Modeling * Fix transformers init * Add conversion script * Add conversion script * Fix docs * Update visualbert modelling * Update configuration * Style fixes * Add model and integration tests * Add all tests * Update model mapping * Add simple detector from original repository * Update docs and configs * Fix style * Fix style * Update docs * Fix style * Fix import issues in style * Fix style * Add changes from review * Fix style * Fix style * Update docs * Fix style * Fix style * Update docs/source/model_doc/visual_bert.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update tests/test_modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Add changes from review * Remove convert run script * Add changes from review * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/visual_bert/modeling_visual_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Add changes from review * Add changes from review * Add visual embedding example in docs * Fix "copied from" comments * Add changes from review * Fix error, style, checkpoints * Update docs * Fix integration tests * Fix style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
43f46aa7fd
commit
88ca6a231d
@ -251,6 +251,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
1. **[VisualBERT](https://huggingface.co/transformers/model_doc/visual_bert.html)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
|
||||
1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
|
@ -256,22 +256,25 @@ Supported models
|
||||
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
|
||||
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
|
||||
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
54. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
54. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
|
||||
Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
|
||||
Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
|
||||
55. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
55. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
56. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
56. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
57. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
57. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
58. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
58. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
59. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
59. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
60. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
|
||||
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||
|
||||
@ -389,6 +392,8 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
@ -537,6 +542,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
model_doc/tapas
|
||||
model_doc/transformerxl
|
||||
model_doc/vit
|
||||
model_doc/visual_bert
|
||||
model_doc/wav2vec2
|
||||
model_doc/xlm
|
||||
model_doc/xlmprophetnet
|
||||
|
128
docs/source/model_doc/visual_bert.rst
Normal file
128
docs/source/model_doc/visual_bert.rst
Normal file
@ -0,0 +1,128 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
VisualBERT
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The VisualBERT model was proposed in `VisualBERT: A Simple and Performant Baseline for Vision and Language
|
||||
<https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
|
||||
VisualBERT is a neural network trained on a variety of (image, text) pairs.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We propose VisualBERT, a simple and flexible framework for modeling a broad range of vision-and-language tasks.
|
||||
VisualBERT consists of a stack of Transformer layers that implicitly align elements of an input text and regions in an
|
||||
associated input image with self-attention. We further propose two visually-grounded language model objectives for
|
||||
pre-training VisualBERT on image caption data. Experiments on four vision-and-language tasks including VQA, VCR, NLVR2,
|
||||
and Flickr30K show that VisualBERT outperforms or rivals with state-of-the-art models while being significantly
|
||||
simpler. Further analysis demonstrates that VisualBERT can ground elements of language to image regions without any
|
||||
explicit supervision and is even sensitive to syntactic relationships, tracking, for example, associations between
|
||||
verbs and image regions corresponding to their arguments.*
|
||||
|
||||
Tips:
|
||||
|
||||
1. Most of the checkpoints provided work with the :class:`~transformers.VisualBertForPreTraining` configuration. Other
|
||||
checkpoints provided are the fine-tuned checkpoints for down-stream tasks - VQA ('visualbert-vqa'), VCR
|
||||
('visualbert-vcr'), NLVR2 ('visualbert-nlvr2'). Hence, if you are not working on these downstream tasks, it is
|
||||
recommended that you use the pretrained checkpoints.
|
||||
|
||||
2. For the VCR task, the authors use a fine-tuned detector for generating visual embeddings, for all the checkpoints.
|
||||
We do not provide the detector and its weights as a part of the package, but it will be available in the research
|
||||
projects, and the states can be loaded directly into the detector provided.
|
||||
|
||||
Usage
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
VisualBERT is a multi-modal vision and language model. It can be used for visual question answering, multiple choice,
|
||||
visual reasoning and region-to-phrase correspondence tasks. VisualBERT uses a BERT-like transformer to prepare
|
||||
embeddings for image-text pairs. Both the text and visual features are then projected to a latent space with identical
|
||||
dimension.
|
||||
|
||||
To feed images to the model, each image is passed through a pre-trained object detector and the regions and the
|
||||
bounding boxes are extracted. The authors use the features generated after passing these regions through a pre-trained
|
||||
CNN like ResNet as visual embeddings. They also add absolute position embeddings, and feed the resulting sequence of
|
||||
vectors to a standard BERT model. The text input is concatenated in the front of the visual embeddings in the embedding
|
||||
layer, and is expected to be bound by [CLS] and a [SEP] tokens, as in BERT. The segment IDs must also be set
|
||||
appropriately for the textual and visual parts.
|
||||
|
||||
The :class:`~transformers.BertTokenizer` is used to encode the text. A custom detector/feature extractor must be used
|
||||
to get the visual embeddings. For an example on how to generate visual embeddings, see the `colab notebook
|
||||
<https://colab.research.google.com/drive/1bLGxKdldwqnMVA5x4neY7-l_8fKGWQYI?usp=sharing>`__. The following example shows
|
||||
how to get the last hidden state using :class:`~transformers.VisualBertModel`:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import BertTokenizer, VisualBertModel
|
||||
|
||||
>>> model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
|
||||
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
>>> inputs = tokenizer("What is the man eating?", return_tensors="pt")
|
||||
>>> # this is a custom function that returns the visual embeddings given the image path
|
||||
>>> visual_embeds = get_visual_embeddings(image_path)
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
This model was contributed by `gchhablani <https://huggingface.co/gchhablani>`__. The original code can be found `here
|
||||
<https://github.com/uclanlp/visualbert>`__.
|
||||
|
||||
VisualBertConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertConfig
|
||||
:members:
|
||||
|
||||
VisualBertModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertModel
|
||||
:members: forward
|
||||
|
||||
|
||||
VisualBertForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertForPreTraining
|
||||
:members: forward
|
||||
|
||||
|
||||
VisualBertForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertForQuestionAnswering
|
||||
:members: forward
|
||||
|
||||
|
||||
VisualBertForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertForMultipleChoice
|
||||
:members: forward
|
||||
|
||||
|
||||
VisualBertForVisualReasoning
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertForVisualReasoning
|
||||
:members: forward
|
||||
|
||||
|
||||
VisualBertForRegionToPhraseAlignment
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisualBertForRegionToPhraseAlignment
|
||||
:members: forward
|
@ -233,6 +233,7 @@ _import_structure = {
|
||||
"TransfoXLCorpus",
|
||||
"TransfoXLTokenizer",
|
||||
],
|
||||
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
|
||||
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
"models.wav2vec2": [
|
||||
"WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
@ -996,6 +997,19 @@ if is_torch_available():
|
||||
"load_tf_weights_in_transfo_xl",
|
||||
]
|
||||
)
|
||||
_import_structure["models.visual_bert"].extend(
|
||||
[
|
||||
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"VisualBertForMultipleChoice",
|
||||
"VisualBertForPreTraining",
|
||||
"VisualBertForQuestionAnswering",
|
||||
"VisualBertForRegionToPhraseAlignment",
|
||||
"VisualBertForVisualReasoning",
|
||||
"VisualBertLayer",
|
||||
"VisualBertModel",
|
||||
"VisualBertPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vit"].extend(
|
||||
[
|
||||
"VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -1702,6 +1716,7 @@ if TYPE_CHECKING:
|
||||
TransfoXLCorpus,
|
||||
TransfoXLTokenizer,
|
||||
)
|
||||
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
|
||||
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
from .models.wav2vec2 import (
|
||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -2338,6 +2353,17 @@ if TYPE_CHECKING:
|
||||
TransfoXLPreTrainedModel,
|
||||
load_tf_weights_in_transfo_xl,
|
||||
)
|
||||
from .models.visual_bert import ( # load_tf_weights_in_visual_bert,
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
VisualBertForMultipleChoice,
|
||||
VisualBertForPreTraining,
|
||||
VisualBertForQuestionAnswering,
|
||||
VisualBertForRegionToPhraseAlignment,
|
||||
VisualBertForVisualReasoning,
|
||||
VisualBertLayer,
|
||||
VisualBertModel,
|
||||
VisualBertPreTrainedModel,
|
||||
)
|
||||
from .models.vit import (
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
ViTForImageClassification,
|
||||
|
@ -74,6 +74,7 @@ from . import (
|
||||
t5,
|
||||
tapas,
|
||||
transfo_xl,
|
||||
visual_bert,
|
||||
vit,
|
||||
wav2vec2,
|
||||
xlm,
|
||||
|
@ -77,6 +77,7 @@ from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFI
|
||||
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
|
||||
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
||||
from ..visual_bert.configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
|
||||
from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
|
||||
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
||||
@ -92,6 +93,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
# Add archive maps here
|
||||
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -148,6 +150,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
CONFIG_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("visual_bert", VisualBertConfig),
|
||||
("roformer", RoFormerConfig),
|
||||
("clip", CLIPConfig),
|
||||
("bigbird_pegasus", BigBirdPegasusConfig),
|
||||
@ -210,6 +213,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
MODEL_NAMES_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("visual_bert", "VisualBert"),
|
||||
("roformer", "RoFormer"),
|
||||
("clip", "CLIP"),
|
||||
("bigbird_pegasus", "BigBirdPegasus"),
|
||||
|
@ -266,6 +266,7 @@ from ..tapas.modeling_tapas import (
|
||||
TapasModel,
|
||||
)
|
||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from ..visual_bert.modeling_visual_bert import VisualBertForPreTraining, VisualBertModel
|
||||
from ..vit.modeling_vit import ViTForImageClassification, ViTModel
|
||||
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
|
||||
from ..xlm.modeling_xlm import (
|
||||
@ -349,6 +350,7 @@ from .configuration_auto import (
|
||||
T5Config,
|
||||
TapasConfig,
|
||||
TransfoXLConfig,
|
||||
VisualBertConfig,
|
||||
ViTConfig,
|
||||
Wav2Vec2Config,
|
||||
XLMConfig,
|
||||
@ -364,6 +366,7 @@ logger = logging.get_logger(__name__)
|
||||
MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(VisualBertConfig, VisualBertModel),
|
||||
(RoFormerConfig, RoFormerModel),
|
||||
(CLIPConfig, CLIPModel),
|
||||
(BigBirdPegasusConfig, BigBirdPegasusModel),
|
||||
@ -425,6 +428,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for pre-training mapping
|
||||
(VisualBertConfig, VisualBertForPreTraining),
|
||||
(LayoutLMConfig, LayoutLMForMaskedLM),
|
||||
(RetriBertConfig, RetriBertModel),
|
||||
(T5Config, T5ForConditionalGeneration),
|
||||
|
74
src/transformers/models/visual_bert/__init__.py
Normal file
74
src/transformers/models/visual_bert/__init__.py
Normal file
@ -0,0 +1,74 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_visual_bert"] = [
|
||||
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"VisualBertForMultipleChoice",
|
||||
"VisualBertForPreTraining",
|
||||
"VisualBertForQuestionAnswering",
|
||||
"VisualBertForRegionToPhraseAlignment",
|
||||
"VisualBertForVisualReasoning",
|
||||
"VisualBertLayer",
|
||||
"VisualBertModel",
|
||||
"VisualBertPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_visual_bert import (
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
VisualBertForMultipleChoice,
|
||||
VisualBertForPreTraining,
|
||||
VisualBertForQuestionAnswering,
|
||||
VisualBertForRegionToPhraseAlignment,
|
||||
VisualBertForVisualReasoning,
|
||||
VisualBertLayer,
|
||||
VisualBertModel,
|
||||
VisualBertPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
class _LazyModule(_BaseLazyModule):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
__file__ = globals()["__file__"]
|
||||
__path__ = [os.path.dirname(__file__)]
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
|
145
src/transformers/models/visual_bert/configuration_visual_bert.py
Normal file
145
src/transformers/models/visual_bert/configuration_visual_bert.py
Normal file
@ -0,0 +1,145 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" VisualBERT model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"uclanlp/visualbert-vqa": "https://huggingface.co/uclanlp/visualbert-vqa/resolve/main/config.json",
|
||||
"uclanlp/visualbert-vqa-pre": "https://huggingface.co/uclanlp/visualbert-vqa-pre/resolve/main/config.json",
|
||||
"uclanlp/visualbert-vqa-coco-pre": "https://huggingface.co/uclanlp/visualbert-vqa-coco-pre/resolve/main/config.json",
|
||||
"uclanlp/visualbert-vcr": "https://huggingface.co/uclanlp/visualbert-vcr/resolve/main/config.json",
|
||||
"uclanlp/visualbert-vcr-pre": "https://huggingface.co/uclanlp/visualbert-vcr-pre/resolve/main/config.json",
|
||||
"uclanlp/visualbert-vcr-coco-pre": "https://huggingface.co/uclanlp/visualbert-vcr-coco-pre/resolve/main/config.json",
|
||||
"uclanlp/visualbert-nlvr2": "https://huggingface.co/uclanlp/visualbert-nlvr2/resolve/main/config.json",
|
||||
"uclanlp/visualbert-nlvr2-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-pre/resolve/main/config.json",
|
||||
"uclanlp/visualbert-nlvr2-coco-pre": "https://huggingface.co/uclanlp/visualbert-nlvr2-coco-pre/resolve/main/config.json"
|
||||
# See all VisualBERT models at https://huggingface.co/models?filter=visual_bert
|
||||
}
|
||||
|
||||
|
||||
class VisualBertConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.VisualBertModel`. It is used
|
||||
to instantiate an VisualBERT model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the VisualBERT
|
||||
`visualbert-vqa-coco-pre <https://huggingface.co/uclanlp/visualbert-vqa-coco-pre>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 30522):
|
||||
Vocabulary size of the VisualBERT model. Defines the number of different tokens that can be represented by
|
||||
the :obj:`inputs_ids` passed when calling :class:`~transformers.VisualBertModel`. Vocabulary size of the
|
||||
model. Defines the different tokens that can be represented by the ``inputs_ids`` passed to the forward
|
||||
method of :class:`~transformers.VisualBertModel`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
visual_embedding_dim (:obj:`int`, `optional`, defaults to 512):
|
||||
Dimensionality of the visual embeddings to be passed to the model.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
|
||||
The vocabulary size of the :obj:`token_type_ids` passed when calling
|
||||
:class:`~transformers.VisualBertModel`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
bypass_transformer (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the model should bypass the transformer for the visual embeddings. If set to :obj:`True`,
|
||||
the model directly concatenates the visual embeddings from :class:`~transformers.VisualBertEmbeddings` with
|
||||
text output from transformers, and then pass it to a self-attention layer.
|
||||
special_visual_initialize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the visual token type and position type embedding weights should be initialized the same as
|
||||
the textual token type and positive type embeddings. When set to :obj:`True`, the weights of the textual
|
||||
token type and position type embeddings are copied to the respective visual embedding layers.
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import VisualBertModel, VisualBertConfig
|
||||
|
||||
>>> # Initializing a VisualBERT visualbert-vqa-coco-pre style configuration
|
||||
>>> configuration = VisualBertConfig.from_pretrained('visualbert-vqa-coco-pre')
|
||||
|
||||
>>> # Initializing a model from the visualbert-vqa-coco-pre style configuration
|
||||
>>> model = VisualBertModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
|
||||
model_type = "visual_bert"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
visual_embedding_dim=512,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
bypass_transformer=False,
|
||||
special_visual_initialize=True,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.visual_embedding_dim = visual_embedding_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.initializer_range = initializer_range
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.bypass_transformer = bypass_transformer
|
||||
self.special_visual_initialize = special_visual_initialize
|
@ -0,0 +1,150 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert VisualBert checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
VisualBertConfig,
|
||||
VisualBertForMultipleChoice,
|
||||
VisualBertForPreTraining,
|
||||
VisualBertForQuestionAnswering,
|
||||
VisualBertForVisualReasoning,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
rename_keys_prefix = [
|
||||
("bert.bert", "visual_bert"),
|
||||
("bert.cls", "cls"),
|
||||
("bert.classifier", "cls"),
|
||||
("token_type_embeddings_visual", "visual_token_type_embeddings"),
|
||||
("position_embeddings_visual", "visual_position_embeddings"),
|
||||
("projection", "visual_projection"),
|
||||
]
|
||||
|
||||
ACCEPTABLE_CHECKPOINTS = [
|
||||
"nlvr2_coco_pre_trained.th",
|
||||
"nlvr2_fine_tuned.th",
|
||||
"nlvr2_pre_trained.th",
|
||||
"vcr_coco_pre_train.th",
|
||||
"vcr_fine_tune.th",
|
||||
"vcr_pre_train.th",
|
||||
"vqa_coco_pre_trained.th",
|
||||
"vqa_fine_tuned.th",
|
||||
"vqa_pre_trained.th",
|
||||
]
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_path):
|
||||
sd = torch.load(checkpoint_path, map_location="cpu")
|
||||
return sd
|
||||
|
||||
|
||||
def get_new_dict(d, config, rename_keys_prefix=rename_keys_prefix):
|
||||
new_d = OrderedDict()
|
||||
new_d["visual_bert.embeddings.position_ids"] = torch.arange(config.max_position_embeddings).expand((1, -1))
|
||||
# detector_d = OrderedDict()
|
||||
for key in d:
|
||||
if "detector" in key:
|
||||
# detector_d[key.replace('detector.','')] = d[key]
|
||||
continue
|
||||
new_key = key
|
||||
for name_pair in rename_keys_prefix:
|
||||
new_key = new_key.replace(name_pair[0], name_pair[1])
|
||||
new_d[new_key] = d[key]
|
||||
if key == "bert.cls.predictions.decoder.weight":
|
||||
# Old bert code didn't have `decoder.bias`, but was added separately
|
||||
new_d["cls.predictions.decoder.bias"] = new_d["cls.predictions.bias"]
|
||||
return new_d
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_visual_bert_checkpoint(checkpoint_path, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our VisualBERT structure.
|
||||
"""
|
||||
|
||||
assert (
|
||||
checkpoint_path.split("/")[-1] in ACCEPTABLE_CHECKPOINTS
|
||||
), f"The checkpoint provided must be in {ACCEPTABLE_CHECKPOINTS}."
|
||||
|
||||
# Get Config
|
||||
if "pre" in checkpoint_path:
|
||||
model_type = "pretraining"
|
||||
if "vcr" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 512}
|
||||
elif "vqa_advanced" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 2048}
|
||||
elif "vqa" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 2048}
|
||||
elif "nlvr" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 1024}
|
||||
else:
|
||||
raise NotImplementedError(f"No implementation found for `{checkpoint_path}`.")
|
||||
else:
|
||||
if "vcr" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 512}
|
||||
model_type = "multichoice"
|
||||
elif "vqa_advanced" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 2048}
|
||||
model_type = "vqa_advanced"
|
||||
elif "vqa" in checkpoint_path:
|
||||
config_params = {"visual_embedding_dim": 2048, "num_labels": 3129}
|
||||
model_type = "vqa"
|
||||
elif "nlvr" in checkpoint_path:
|
||||
config_params = {
|
||||
"visual_embedding_dim": 1024,
|
||||
"num_labels": 2,
|
||||
}
|
||||
model_type = "nlvr"
|
||||
|
||||
config = VisualBertConfig(**config_params)
|
||||
|
||||
# Load State Dict
|
||||
state_dict = load_state_dict(checkpoint_path)
|
||||
|
||||
new_state_dict = get_new_dict(state_dict, config)
|
||||
|
||||
if model_type == "pretraining":
|
||||
model = VisualBertForPreTraining(config)
|
||||
elif model_type == "vqa":
|
||||
model = VisualBertForQuestionAnswering(config)
|
||||
elif model_type == "nlvr":
|
||||
model = VisualBertForVisualReasoning(config)
|
||||
elif model_type == "multichoice":
|
||||
model = VisualBertForMultipleChoice(config)
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
# Save Checkpoints
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("orig_checkpoint_path", type=str, help="A path to .th on local filesystem.")
|
||||
parser.add_argument("pytorch_dump_folder_path", type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
convert_visual_bert_checkpoint(args.orig_checkpoint_path, args.pytorch_dump_folder_path)
|
1559
src/transformers/models/visual_bert/modeling_visual_bert.py
Executable file
1559
src/transformers/models/visual_bert/modeling_visual_bert.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -2864,6 +2864,65 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
|
||||
requires_backends(load_tf_weights_in_transfo_xl, ["torch"])
|
||||
|
||||
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class VisualBertForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertForPreTraining:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertForRegionToPhraseAlignment:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertForVisualReasoning:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertLayer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisualBertPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
689
tests/test_modeling_visual_bert.py
Normal file
689
tests/test_modeling_visual_bert.py
Normal file
@ -0,0 +1,689 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch VisualBERT model. """
|
||||
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
VisualBertConfig,
|
||||
VisualBertForMultipleChoice,
|
||||
VisualBertForPreTraining,
|
||||
VisualBertForQuestionAnswering,
|
||||
VisualBertForRegionToPhraseAlignment,
|
||||
VisualBertForVisualReasoning,
|
||||
VisualBertModel,
|
||||
)
|
||||
from transformers.models.visual_bert.modeling_visual_bert import VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class VisualBertModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
visual_seq_length=5,
|
||||
is_training=True,
|
||||
use_attention_mask=True,
|
||||
use_visual_attention_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_visual_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
visual_embedding_dim=20,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.visual_seq_length = visual_seq_length
|
||||
self.is_training = is_training
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.use_visual_attention_mask = use_visual_attention_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_visual_token_type_ids = use_visual_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.visual_embedding_dim = visual_embedding_dim
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config(self):
|
||||
return VisualBertConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
visual_embedding_dim=self.visual_embedding_dim,
|
||||
num_labels=self.num_labels,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
visual_embeds = floats_tensor([self.batch_size, self.visual_seq_length, self.visual_embedding_dim])
|
||||
|
||||
attention_mask = None
|
||||
if self.use_attention_mask:
|
||||
attention_mask = torch.ones((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)
|
||||
|
||||
visual_attention_mask = None
|
||||
if self.use_visual_attention_mask:
|
||||
visual_attention_mask = torch.ones(
|
||||
(self.batch_size, self.visual_seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
visual_token_type_ids = None
|
||||
if self.use_visual_token_type_ids:
|
||||
visual_token_type_ids = ids_tensor([self.batch_size, self.visual_seq_length], self.type_vocab_size)
|
||||
|
||||
config = self.prepare_config()
|
||||
return config, {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"visual_embeds": visual_embeds,
|
||||
"visual_token_type_ids": visual_token_type_ids,
|
||||
"visual_attention_mask": visual_attention_mask,
|
||||
}
|
||||
|
||||
def prepare_config_and_inputs_for_pretraining(self):
|
||||
masked_lm_labels = None
|
||||
sentence_image_labels = None
|
||||
|
||||
if self.use_labels:
|
||||
masked_lm_labels = ids_tensor([self.batch_size, self.seq_length + self.visual_seq_length], self.vocab_size)
|
||||
sentence_image_labels = ids_tensor(
|
||||
[self.batch_size],
|
||||
self.type_sequence_label_size,
|
||||
)
|
||||
|
||||
config, input_dict = self.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_dict.update({"labels": masked_lm_labels, "sentence_image_labels": sentence_image_labels})
|
||||
|
||||
return config, input_dict
|
||||
|
||||
def prepare_config_and_inputs_for_multiple_choice(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.num_choices, self.seq_length], self.vocab_size)
|
||||
visual_embeds = floats_tensor(
|
||||
[self.batch_size, self.num_choices, self.visual_seq_length, self.visual_embedding_dim]
|
||||
)
|
||||
|
||||
attention_mask = None
|
||||
if self.use_attention_mask:
|
||||
attention_mask = torch.ones(
|
||||
(self.batch_size, self.num_choices, self.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
visual_attention_mask = None
|
||||
if self.use_visual_attention_mask:
|
||||
visual_attention_mask = torch.ones(
|
||||
(self.batch_size, self.num_choices, self.visual_seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.num_choices, self.seq_length], self.type_vocab_size)
|
||||
|
||||
visual_token_type_ids = None
|
||||
if self.use_visual_token_type_ids:
|
||||
visual_token_type_ids = ids_tensor(
|
||||
[self.batch_size, self.num_choices, self.visual_seq_length], self.type_vocab_size
|
||||
)
|
||||
|
||||
labels = None
|
||||
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.prepare_config()
|
||||
return config, {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"visual_embeds": visual_embeds,
|
||||
"visual_token_type_ids": visual_token_type_ids,
|
||||
"visual_attention_mask": visual_attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
def prepare_config_and_inputs_for_vqa(self):
|
||||
vqa_labels = None
|
||||
|
||||
if self.use_labels:
|
||||
vqa_labels = floats_tensor([self.batch_size, self.num_labels])
|
||||
|
||||
config, input_dict = self.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_dict.update({"labels": vqa_labels})
|
||||
return config, input_dict
|
||||
|
||||
def prepare_config_and_inputs_for_nlvr(self):
|
||||
nlvr_labels = None
|
||||
|
||||
if self.use_labels:
|
||||
nlvr_labels = ids_tensor([self.batch_size], self.num_labels)
|
||||
|
||||
config, input_dict = self.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_dict.update({"labels": nlvr_labels})
|
||||
return config, input_dict
|
||||
|
||||
def prepare_config_and_inputs_for_flickr(self):
|
||||
region_to_phrase_position = torch.cat(
|
||||
(
|
||||
ids_tensor([self.batch_size, self.seq_length], self.visual_seq_length),
|
||||
torch.ones(self.batch_size, self.visual_seq_length, dtype=torch.long, device=torch_device) * -1,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
flickr_labels = None
|
||||
if self.use_labels:
|
||||
flickr_labels = floats_tensor(
|
||||
[self.batch_size, self.seq_length + self.visual_seq_length, self.visual_seq_length]
|
||||
)
|
||||
|
||||
config, input_dict = self.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_dict.update({"region_to_phrase_position": region_to_phrase_position, "labels": flickr_labels})
|
||||
return config, input_dict
|
||||
|
||||
def create_and_check_model(self, config, input_dict):
|
||||
model = VisualBertModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(**input_dict)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
(self.batch_size, self.seq_length + self.visual_seq_length, self.hidden_size),
|
||||
)
|
||||
|
||||
def create_and_check_for_pretraining(self, config, input_dict):
|
||||
model = VisualBertForPreTraining(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(**input_dict)
|
||||
self.parent.assertEqual(
|
||||
result.prediction_logits.shape,
|
||||
(self.batch_size, self.seq_length + self.visual_seq_length, self.vocab_size),
|
||||
)
|
||||
|
||||
def create_and_check_for_vqa(self, config, input_dict):
|
||||
model = VisualBertForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(**input_dict)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(self, config, input_dict):
|
||||
model = VisualBertForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(**input_dict)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_for_nlvr(self, config, input_dict):
|
||||
model = VisualBertForVisualReasoning(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(**input_dict)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_flickr(self, config, input_dict):
|
||||
model = VisualBertForRegionToPhraseAlignment(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(**input_dict)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.seq_length + self.visual_seq_length, self.visual_seq_length)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class VisualBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
VisualBertModel,
|
||||
VisualBertForMultipleChoice,
|
||||
VisualBertForVisualReasoning,
|
||||
VisualBertForRegionToPhraseAlignment,
|
||||
VisualBertForQuestionAnswering,
|
||||
VisualBertForPreTraining,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_torchscript = False
|
||||
test_pruning = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
if model_class == VisualBertForMultipleChoice:
|
||||
for key in inputs_dict.keys():
|
||||
value = inputs_dict[key]
|
||||
if isinstance(value, torch.Tensor) and value.ndim > 1:
|
||||
if key != "visual_embeds":
|
||||
inputs_dict[key] = (
|
||||
inputs_dict[key].unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
)
|
||||
else:
|
||||
inputs_dict[key] = (
|
||||
inputs_dict[key]
|
||||
.unsqueeze(1)
|
||||
.expand(-1, self.model_tester.num_choices, -1, self.model_tester.visual_embedding_dim)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
elif model_class == VisualBertForRegionToPhraseAlignment:
|
||||
total_length = self.model_tester.seq_length + self.model_tester.visual_seq_length
|
||||
batch_size = self.model_tester.batch_size
|
||||
inputs_dict["region_to_phrase_position"] = torch.zeros(
|
||||
(batch_size, total_length),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
if return_labels:
|
||||
if model_class == VisualBertForMultipleChoice:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class == VisualBertForPreTraining:
|
||||
total_length = self.model_tester.seq_length + self.model_tester.visual_seq_length
|
||||
batch_size = self.model_tester.batch_size
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(batch_size, total_length),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
inputs_dict["sentence_image_labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
# Flickr expects float labels
|
||||
elif model_class == VisualBertForRegionToPhraseAlignment:
|
||||
batch_size = self.model_tester.batch_size
|
||||
total_length = self.model_tester.seq_length + self.model_tester.visual_seq_length
|
||||
|
||||
inputs_dict["labels"] = torch.ones(
|
||||
(
|
||||
batch_size,
|
||||
total_length,
|
||||
self.model_tester.visual_seq_length,
|
||||
),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
# VQA expects float labels
|
||||
elif model_class == VisualBertForQuestionAnswering:
|
||||
inputs_dict["labels"] = torch.ones(
|
||||
(self.model_tester.batch_size, self.model_tester.num_labels),
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
elif model_class == VisualBertForVisualReasoning:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size), dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VisualBertModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=VisualBertConfig, hidden_size=37)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
visual_seq_len = getattr(self.model_tester, "visual_seq_length", None)
|
||||
|
||||
encoder_seq_length = (seq_len if seq_len is not None else 0) + (
|
||||
visual_seq_len if visual_seq_len is not None else 0
|
||||
)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
chunk_length = getattr(self.model_tester, "chunk_length", None)
|
||||
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
|
||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
if chunk_length is not None:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-4:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
|
||||
)
|
||||
else:
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||
seq_length = self.model_tester.encoder_seq_length
|
||||
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
|
||||
seq_length = seq_length * self.model_tester.chunk_length
|
||||
else:
|
||||
seq_length = self.model_tester.seq_length + self.model_tester.visual_seq_length
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_for_pretraining(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_pretraining()
|
||||
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
|
||||
|
||||
def test_model_for_vqa(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_vqa()
|
||||
self.model_tester.create_and_check_for_vqa(*config_and_inputs)
|
||||
|
||||
def test_model_for_nlvr(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_nlvr()
|
||||
self.model_tester.create_and_check_for_nlvr(*config_and_inputs)
|
||||
|
||||
def test_model_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_multiple_choice()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_model_for_flickr(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_flickr()
|
||||
self.model_tester.create_and_check_for_flickr(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = VisualBertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class VisualBertModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_vqa_coco_pre(self):
|
||||
model = VisualBertForPreTraining.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
|
||||
|
||||
input_ids = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.long).reshape(1, -1)
|
||||
token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long).reshape(1, -1)
|
||||
visual_embeds = torch.ones(size=(1, 10, 2048), dtype=torch.float32) * 0.5
|
||||
visual_token_type_ids = torch.ones(size=(1, 10), dtype=torch.long)
|
||||
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
||||
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
visual_embeds=visual_embeds,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
visual_token_type_ids=visual_token_type_ids,
|
||||
)
|
||||
|
||||
vocab_size = 30522
|
||||
|
||||
expected_shape = torch.Size((1, 16, vocab_size))
|
||||
self.assertEqual(output.prediction_logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[[-5.1858, -5.1903, -4.9142], [-6.2214, -5.9238, -5.8381], [-6.3027, -5.9939, -5.9297]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output.prediction_logits[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
expected_shape_2 = torch.Size((1, 2))
|
||||
self.assertEqual(output.seq_relationship_logits.shape, expected_shape_2)
|
||||
|
||||
expected_slice_2 = torch.tensor([[0.7393, 0.1754]])
|
||||
|
||||
self.assertTrue(torch.allclose(output.seq_relationship_logits, expected_slice_2, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_vqa(self):
|
||||
model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")
|
||||
|
||||
input_ids = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.long).reshape(1, -1)
|
||||
token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long).reshape(1, -1)
|
||||
visual_embeds = torch.ones(size=(1, 10, 2048), dtype=torch.float32) * 0.5
|
||||
visual_token_type_ids = torch.ones(size=(1, 10), dtype=torch.long)
|
||||
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
||||
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
visual_embeds=visual_embeds,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
visual_token_type_ids=visual_token_type_ids,
|
||||
)
|
||||
|
||||
# vocab_size = 30522
|
||||
|
||||
expected_shape = torch.Size((1, 3129))
|
||||
self.assertEqual(output.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[[-8.9898, 3.0803, -1.8016, 2.4542, -8.3420, -2.0224, -3.3124, -4.4139, -3.1491, -3.8997]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output.logits[:, :10], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_nlvr(self):
|
||||
model = VisualBertForVisualReasoning.from_pretrained("uclanlp/visualbert-nlvr2")
|
||||
|
||||
input_ids = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.long).reshape(1, -1)
|
||||
token_type_ids = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long).reshape(1, -1)
|
||||
visual_embeds = torch.ones(size=(1, 10, 1024), dtype=torch.float32) * 0.5
|
||||
visual_token_type_ids = torch.ones(size=(1, 10), dtype=torch.long)
|
||||
attention_mask = torch.tensor([1] * 6).reshape(1, -1)
|
||||
visual_attention_mask = torch.tensor([1] * 10).reshape(1, -1)
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
visual_embeds=visual_embeds,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
visual_token_type_ids=visual_token_type_ids,
|
||||
)
|
||||
|
||||
# vocab_size = 30522
|
||||
|
||||
expected_shape = torch.Size((1, 2))
|
||||
self.assertEqual(output.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([[-1.1436, 0.8900]])
|
||||
|
||||
self.assertTrue(torch.allclose(output.logits, expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_vcr(self):
|
||||
model = VisualBertForMultipleChoice.from_pretrained("uclanlp/visualbert-vcr")
|
||||
|
||||
input_ids = torch.tensor([[[1, 2, 3, 4, 5, 6] for i in range(4)]], dtype=torch.long)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
token_type_ids = torch.ones_like(input_ids)
|
||||
|
||||
visual_embeds = torch.ones(size=(1, 4, 10, 512), dtype=torch.float32) * 0.5
|
||||
visual_token_type_ids = torch.ones(size=(1, 4, 10), dtype=torch.long)
|
||||
visual_attention_mask = torch.ones_like(visual_token_type_ids)
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
visual_embeds=visual_embeds,
|
||||
visual_attention_mask=visual_attention_mask,
|
||||
visual_token_type_ids=visual_token_type_ids,
|
||||
)
|
||||
|
||||
# vocab_size = 30522
|
||||
|
||||
expected_shape = torch.Size((1, 4))
|
||||
self.assertEqual(output.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([[-7.7697, -7.7697, -7.7697, -7.7697]])
|
||||
|
||||
self.assertTrue(torch.allclose(output.logits, expected_slice, atol=1e-4))
|
@ -118,6 +118,10 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
||||
"XLMForQuestionAnswering",
|
||||
"XLNetForQuestionAnswering",
|
||||
"SeparableConv1D",
|
||||
"VisualBertForRegionToPhraseAlignment",
|
||||
"VisualBertForVisualReasoning",
|
||||
"VisualBertForQuestionAnswering",
|
||||
"VisualBertForMultipleChoice",
|
||||
]
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
|
Loading…
Reference in New Issue
Block a user