mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Add TrOCR + VisionEncoderDecoderModel (#13874)
* First draft * Update self-attention of RoBERTa as proposition * Improve conversion script * Add TrOCR decoder-only model * More improvements * Make forward pass with pretrained weights work * More improvements * Some more improvements * More improvements * Make conversion work * Clean up print statements * Add documentation, processor * Add test files * Small improvements * Some more improvements * Make fix-copies, improve docs * Make all vision encoder decoder model tests pass * Make conversion script support other models * Update URL for OCR image * Update conversion script * Fix style & quality * Add support for the large-printed model * Fix some issues * Add print statement for debugging * Add print statements for debugging * Make possible fix for sinusoidal embedding * Further debugging * Potential fix v2 * Add more print statements for debugging * Add more print statements for debugging * Deubg more * Comment out print statements * Make conversion of large printed model possible, address review comments * Make it possible to convert the stage1 checkpoints * Clean up code, apply suggestions from code review * Apply suggestions from code review, use Microsoft models in tests * Rename encoder_hidden_size to cross_attention_hidden_size * Improve docs
This commit is contained in:
parent
61f6426269
commit
408b2d2bd0
@ -267,7 +267,6 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
|
||||
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
|
||||
1. **[SpeechEncoderDecoder](https://huggingface.co/transformers/model_doc/speechencoderdecoder.html)**
|
||||
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
|
||||
1. **[SpeechToTextTransformer2](https://huggingface.co/transformers/model_doc/speech_to_text_2.html)** (from Facebook), released together with the paper [Large-Scale Self- and Semi-Supervised Learning for Speech Translation](https://arxiv.org/abs/2104.06678) by Changhan Wang, Anne Wu, Juan Pino, Alexei Baevski, Michael Auli, Alexis Conneau.
|
||||
1. **[Splinter](https://huggingface.co/transformers/model_doc/splinter.html)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
|
||||
@ -276,6 +275,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[T5v1.1](https://huggingface.co/transformers/model_doc/t5v1.1.html)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[TrOCR](https://huggingface.co/transformers/master/model_doc/trocr.html)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
|
||||
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
1. **[VisualBERT](https://huggingface.co/transformers/model_doc/visual_bert.html)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
|
||||
1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
|
@ -298,7 +298,9 @@ conda install -c huggingface transformers
|
||||
1. **[T5v1.1](https://huggingface.co/transformers/model_doc/t5v1.1.html)** (来自 Google AI) 伴随论文 [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) 由 Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu 发布。
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (来自 Google AI) 伴随论文 [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) 由 Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos 发布。
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (来自 Google/CMU) 伴随论文 [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) 由 Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov 发布。
|
||||
1. **[TrOCR](https://huggingface.co/transformers/master/model_doc/trocr.html)** (来自 Microsoft) 伴随论文 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 由 Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 发布。
|
||||
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。
|
||||
1. **[VisionEncoderDecoder](https://huggingface.co/transformers/model_doc/visionencoderdecoder.html)**
|
||||
1. **[VisualBERT](https://huggingface.co/transformers/model_doc/visual_bert.html)** (来自 UCLA NLP) 伴随论文 [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) 由 Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang 发布。
|
||||
1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (来自 Facebook AI) 伴随论文 [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) 由 Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli 发布。
|
||||
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (来自 Facebook) 伴随论文 [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) 由 Guillaume Lample and Alexis Conneau 发布。
|
||||
|
@ -310,7 +310,9 @@ conda install -c huggingface transformers
|
||||
1. **[T5v1.1](https://huggingface.co/transformers/model_doc/t5v1.1.html)** (from Google AI) released with the paper [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[TrOCR](https://huggingface.co/transformers/master/model_doc/trocr.html)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
|
||||
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
1. **[VisionEncoderDecoder](https://huggingface.co/transformers/model_doc/visionencoderdecoder.html)**
|
||||
1. **[VisualBERT](https://huggingface.co/transformers/model_doc/visual_bert.html)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
|
||||
1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
|
@ -268,33 +268,36 @@ Supported models
|
||||
57. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper a `RoFormer:
|
||||
Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and
|
||||
Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
|
||||
58. :doc:`SpeechEncoderDecoder <model_doc/speechencoderdecoder>`
|
||||
59. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
|
||||
58. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
|
||||
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
|
||||
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
|
||||
60. :doc:`SpeechToTextTransformer2 <model_doc/speech_to_text_2>` (from Facebook), released together with the paper
|
||||
59. :doc:`SpeechToTextTransformer2 <model_doc/speech_to_text_2>` (from Facebook), released together with the paper
|
||||
`Large-Scale Self- and Semi-Supervised Learning for Speech Translation <https://arxiv.org/abs/2104.06678>`__ by
|
||||
Changhan Wang, Anne Wu, Juan Pino, Alexei Baevski, Michael Auli, Alexis Conneau.
|
||||
61. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
|
||||
60. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
|
||||
Question Answering by Pretraining Span Selection <https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain,
|
||||
Jonathan Berant, Amir Globerson, Omer Levy.
|
||||
62. :doc:`SqueezeBert <model_doc/squeezebert>` (from Berkeley) released with the paper `SqueezeBERT: What can computer
|
||||
61. :doc:`SqueezeBert <model_doc/squeezebert>` (from Berkeley) released with the paper `SqueezeBERT: What can computer
|
||||
vision teach NLP about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola,
|
||||
Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
|
||||
63. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
62. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
64. :doc:`T5v1.1 <model_doc/t5v1.1>` (from Google AI) released in the repository
|
||||
63. :doc:`T5v1.1 <model_doc/t5v1.1>` (from Google AI) released in the repository
|
||||
`google-research/text-to-text-transfer-transformer
|
||||
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__ by
|
||||
Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi
|
||||
Zhou and Wei Li and Peter J. Liu.
|
||||
65. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
64. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
66. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
65. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
66. `TrOCR <https://huggingface.co/transformers/master/model_doc/trocr.html>`__ (from Microsoft), released together
|
||||
with the paper `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
|
||||
<https://arxiv.org/abs/2109.10282>`__ by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang,
|
||||
Zhoujun Li, Furu Wei.
|
||||
67. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
|
||||
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
|
||||
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
|
||||
@ -459,6 +462,10 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
@ -623,6 +630,8 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
model_doc/t5v1.1
|
||||
model_doc/tapas
|
||||
model_doc/transformerxl
|
||||
model_doc/trocr
|
||||
model_doc/visionencoderdecoder
|
||||
model_doc/vit
|
||||
model_doc/visual_bert
|
||||
model_doc/wav2vec2
|
||||
|
95
docs/source/model_doc/trocr.rst
Normal file
95
docs/source/model_doc/trocr.rst
Normal file
@ -0,0 +1,95 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
TrOCR
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The TrOCR model was proposed in `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
|
||||
<https://arxiv.org/abs/2109.10282>`__ by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang,
|
||||
Zhoujun Li, Furu Wei. TrOCR consists of an image Transformer encoder and an autoregressive text Transformer decoder to
|
||||
perform `optical character recognition (OCR) <https://en.wikipedia.org/wiki/Optical_character_recognition>`__.
|
||||
|
||||
Please refer to the :doc:`VisionEncoderDecoder <visionencoderdecoder>` class on how to use this model.
|
||||
|
||||
This model was contributed by `Niels Rogge <https://huggingface.co/nielsr>`__.
|
||||
|
||||
The original code can be found `here
|
||||
<https://github.com/microsoft/unilm/tree/6f60612e7cc86a2a1ae85c47231507a587ab4e01/trocr>`__.
|
||||
|
||||
|
||||
Tips:
|
||||
|
||||
- TrOCR is pre-trained in 2 stages before being fine-tuned on downstream datasets. It achieves state-of-the-art results
|
||||
on both printed (e.g. the `SROIE dataset <https://paperswithcode.com/dataset/sroie>`__) and handwritten (e.g. the
|
||||
`IAM Handwriting dataset <https://fki.tic.heia-fr.ch/databases/iam-handwriting-database>`__) text recognition tasks.
|
||||
For more information, see the `official models <https://huggingface.co/models?other=trocr>`__.
|
||||
- TrOCR is always used within the :doc:`VisionEncoderDecoder <visionencoderdecoder>` framework.
|
||||
|
||||
Inference
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
TrOCR's :class:`~transformers.VisionEncoderDecoderModel` model accepts images as input and makes use of
|
||||
:func:`~transformers.generation_utils.GenerationMixin.generate` to autoregressively generate text given the input
|
||||
image.
|
||||
|
||||
The :class:`~transformers.ViTFeatureExtractor` class is responsible for preprocessing the input image and
|
||||
:class:`~transformers.RobertaTokenizer` decodes the generated target tokens to the target string. The
|
||||
:class:`~transformers.TrOCRProcessor` wraps :class:`~transformers.ViTFeatureExtractor` and
|
||||
:class:`~transformers.RobertaTokenizer` into a single instance to both extract the input features and decode the
|
||||
predicted token ids.
|
||||
|
||||
- Step-by-step Optical Character Recognition (OCR)
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
|
||||
|
||||
>>> # load image from the IAM dataset
|
||||
>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
>>> generated_ids = model.generate(pixel_values)
|
||||
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
|
||||
See the `model hub <https://huggingface.co/models?filter=trocr>`__ to look for TrOCR checkpoints.
|
||||
|
||||
|
||||
TrOCRConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrOCRConfig
|
||||
:members:
|
||||
|
||||
|
||||
TrOCRProcessor
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrOCRProcessor
|
||||
:members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
|
||||
|
||||
|
||||
TrOCRForCausalLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TrOCRForCausalLM
|
||||
:members: forward
|
41
docs/source/model_doc/visionencoderdecoder.rst
Normal file
41
docs/source/model_doc/visionencoderdecoder.rst
Normal file
@ -0,0 +1,41 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
Vision Encoder Decoder Models
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
The :class:`~transformers.VisionEncoderDecoderModel` can be used to initialize an image-to-text-sequence model with any
|
||||
pretrained vision autoencoding model as the encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT <deit>`)
|
||||
and any pretrained language model as the decoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`GPT2 <gpt2>`, :doc:`BERT
|
||||
<bert>`).
|
||||
|
||||
The effectiveness of initializing image-to-text-sequence models with pretrained checkpoints has been shown in (for
|
||||
example) `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
|
||||
<https://arxiv.org/abs/2109.10282>`__ by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang,
|
||||
Zhoujun Li, Furu Wei.
|
||||
|
||||
An example of how to use a :class:`~transformers.VisionEncoderDecoderModel` for inference can be seen in :doc:`TrOCR
|
||||
<trocr>`.
|
||||
|
||||
|
||||
VisionEncoderDecoderConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisionEncoderDecoderConfig
|
||||
:members:
|
||||
|
||||
|
||||
VisionEncoderDecoderModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.VisionEncoderDecoderModel
|
||||
:members: forward, from_encoder_decoder_pretrained
|
@ -271,6 +271,12 @@ _import_structure = {
|
||||
"TransfoXLCorpus",
|
||||
"TransfoXLTokenizer",
|
||||
],
|
||||
"models.trocr": [
|
||||
"TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"TrOCRConfig",
|
||||
"TrOCRProcessor",
|
||||
],
|
||||
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
|
||||
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
|
||||
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
|
||||
"models.wav2vec2": [
|
||||
@ -1173,6 +1179,10 @@ if is_torch_available():
|
||||
"load_tf_weights_in_transfo_xl",
|
||||
]
|
||||
)
|
||||
_import_structure["models.trocr"].extend(
|
||||
["TROCR_PRETRAINED_MODEL_ARCHIVE_LIST", "TrOCRForCausalLM", "TrOCRPreTrainedModel"]
|
||||
)
|
||||
_import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
|
||||
_import_structure["models.visual_bert"].extend(
|
||||
[
|
||||
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -2094,6 +2104,8 @@ if TYPE_CHECKING:
|
||||
TransfoXLCorpus,
|
||||
TransfoXLTokenizer,
|
||||
)
|
||||
from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor
|
||||
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
|
||||
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
|
||||
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||
from .models.wav2vec2 import (
|
||||
@ -2853,7 +2865,9 @@ if TYPE_CHECKING:
|
||||
TransfoXLPreTrainedModel,
|
||||
load_tf_weights_in_transfo_xl,
|
||||
)
|
||||
from .models.visual_bert import ( # load_tf_weights_in_visual_bert,
|
||||
from .models.trocr import TROCR_PRETRAINED_MODEL_ARCHIVE_LIST, TrOCRForCausalLM, TrOCRPreTrainedModel
|
||||
from .models.vision_encoder_decoder import VisionEncoderDecoderModel
|
||||
from .models.visual_bert import (
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
VisualBertForMultipleChoice,
|
||||
VisualBertForPreTraining,
|
||||
|
@ -86,6 +86,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
Whether the model is used as an encoder/decoder or not.
|
||||
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether the model is used as decoder or not (in which case it's used as an encoder).
|
||||
cross_attention_hidden_size (:obj:`bool`, `optional`):
|
||||
The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
|
||||
setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
|
||||
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
|
||||
that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which
|
||||
@ -249,6 +252,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
||||
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
||||
self.is_decoder = kwargs.pop("is_decoder", False)
|
||||
self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
|
||||
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
|
||||
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
|
||||
|
||||
|
@ -26,6 +26,8 @@ from ...file_utils import CONFIG_NAME
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
|
||||
("trocr", "TrOCRConfig"),
|
||||
("fnet", "FNetConfig"),
|
||||
("gptj", "GPTJConfig"),
|
||||
("layoutlmv2", "LayoutLMv2Config"),
|
||||
@ -166,6 +168,8 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_NAMES_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("vision-encoder-decoder", "Vision Encoder decoder"),
|
||||
("trocr", "TrOCR"),
|
||||
("fnet", "FNet"),
|
||||
("gptj", "GPT-J"),
|
||||
("beit", "BeiT"),
|
||||
|
@ -188,6 +188,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
("trocr", "TrOCRForCausalLM"),
|
||||
("gptj", "GPTJForCausalLM"),
|
||||
("rembert", "RemBertForCausalLM"),
|
||||
("roformer", "RoFormerForCausalLM"),
|
||||
|
@ -64,6 +64,8 @@ class DeiTConfig(PretrainedConfig):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (:obj:`int`, `optional`, defaults to :obj:`3`):
|
||||
The number of input channels.
|
||||
qkv_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
|
||||
|
||||
Example::
|
||||
@ -96,6 +98,7 @@ class DeiTConfig(PretrainedConfig):
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
qkv_bias=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -113,3 +116,4 @@ class DeiTConfig(PretrainedConfig):
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.qkv_bias = qkv_bias
|
||||
|
@ -135,9 +135,9 @@ class DeiTSelfAttention(nn.Module):
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
@ -476,6 +476,7 @@ class DeiTModel(DeiTPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
|
50
src/transformers/models/trocr/__init__.py
Normal file
50
src/transformers/models/trocr/__init__.py
Normal file
@ -0,0 +1,50 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_trocr": [
|
||||
"TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"TrOCRConfig",
|
||||
],
|
||||
"processing_trocr": ["TrOCRProcessor"],
|
||||
}
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_trocr"] = [
|
||||
"TROCR_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TrOCRForCausalLM",
|
||||
"TrOCRPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig
|
||||
from .processing_trocr import TrOCRProcessor
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_trocr import TROCR_PRETRAINED_MODEL_ARCHIVE_LIST, TrOCRForCausalLM, TrOCRPreTrainedModel
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
150
src/transformers/models/trocr/configuration_trocr.py
Normal file
150
src/transformers/models/trocr/configuration_trocr.py
Normal file
@ -0,0 +1,150 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" TrOCR model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"microsoft/trocr-base": "https://huggingface.co/microsoft/trocr-base/resolve/main/config.json",
|
||||
# See all TrOCR models at https://huggingface.co/models?filter=trocr
|
||||
}
|
||||
|
||||
|
||||
class TrOCRConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.TrOCRForCausalLM`. It is used
|
||||
to instantiate an TrOCR model according to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to that of the TrOCR `microsoft/trocr-base
|
||||
<https://huggingface.co/microsoft/trocr-base>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
||||
Vocabulary size of the TrOCR model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.TrOCRForCausalLM`.
|
||||
d_model (:obj:`int`, `optional`, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
decoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of decoder layers.
|
||||
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the pooler. If string, :obj:`"gelu"`,
|
||||
:obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
dropout (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, and pooler.
|
||||
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
init_std (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the `LayerDrop paper <see
|
||||
https://arxiv.org/abs/1909.11556>`__ for more details.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to scale the word embeddings by sqrt(d_model).
|
||||
use_learned_position_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use learned position embeddings. If not, sinusoidal position embeddings will be used.
|
||||
layernorm_embedding (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use a layernorm after the word + position embeddings.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import TrOCRForCausalLM, TrOCRConfig
|
||||
|
||||
>>> # Initializing a TrOCR-base style configuration
|
||||
>>> configuration = TrOCRConfig()
|
||||
|
||||
>>> # Initializing a model from the TrOCR-base style configuration
|
||||
>>> model = TrOCRForCausalLM(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "trocr"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"num_attention_heads": "decoder_attention_heads",
|
||||
"hidden_size": "d_model",
|
||||
"num_hidden_layers": "decoder_layers",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50265,
|
||||
d_model=1024,
|
||||
decoder_layers=12,
|
||||
decoder_attention_heads=16,
|
||||
decoder_ffn_dim=4096,
|
||||
activation_function="gelu",
|
||||
max_position_embeddings=512,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
decoder_start_token_id=2,
|
||||
classifier_dropout=0.0,
|
||||
init_std=0.02,
|
||||
decoder_layerdrop=0.0,
|
||||
use_cache=False,
|
||||
scale_embedding=False,
|
||||
use_learned_position_embeddings=True,
|
||||
layernorm_embedding=True,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.activation_function = activation_function
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.init_std = init_std
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.use_cache = use_cache
|
||||
self.scale_embedding = scale_embedding
|
||||
self.use_learned_position_embeddings = use_learned_position_embeddings
|
||||
self.layernorm_embedding = layernorm_embedding
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
**kwargs,
|
||||
)
|
969
src/transformers/models/trocr/modeling_trocr.py
Normal file
969
src/transformers/models/trocr/modeling_trocr.py
Normal file
@ -0,0 +1,969 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch TrOCR decoder model (based on RoBERTa). """
|
||||
|
||||
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_trocr import TrOCRConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "TrOCRConfig"
|
||||
_TOKENIZER_FOR_DOC = "TrOCRTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/trocr-base-handwritten"
|
||||
|
||||
|
||||
TROCR_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"microsoft/trocr-base-handwritten",
|
||||
# See all TrOCR models at https://huggingface.co/models?filter=trocr
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
|
||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
|
||||
class TrOCRLearnedPositionalEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
# TrOCR is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models don't have this hack
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids_shape[:2]
|
||||
positions = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
||||
)
|
||||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
class TrOCRSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.offset = 2
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
|
||||
self.register_buffer("_float_tensor", torch.FloatTensor(1))
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
"""
|
||||
Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
|
||||
description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
||||
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
||||
if embedding_dim % 2 == 1:
|
||||
# zero pad
|
||||
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
||||
if padding_idx is not None:
|
||||
emb[padding_idx, :] = 0
|
||||
|
||||
return emb
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
|
||||
bsz, seq_len = input_ids.size()
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
|
||||
input_ids.device
|
||||
)
|
||||
|
||||
# expand embeddings if needed
|
||||
max_pos = self.padding_idx + 1 + seq_len
|
||||
if self.weights is None or max_pos > self.weights.size(0):
|
||||
# recompute/expand embeddings if needed
|
||||
self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
|
||||
self.weights = self.weights.to(self._float_tensor)
|
||||
|
||||
x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
|
||||
|
||||
return x
|
||||
|
||||
def create_position_ids_from_input_ids(
|
||||
self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
|
||||
):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
||||
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
class TrOCRAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
kdim: int = None,
|
||||
vdim: int = None,
|
||||
dropout: float = 0.0,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = True,
|
||||
is_cross_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.kdim = kdim if kdim is not None else embed_dim
|
||||
self.vdim = vdim if vdim is not None else embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
if not (self.head_dim * num_heads == self.embed_dim):
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class TrOCRDecoderLayer(nn.Module):
|
||||
def __init__(self, config: TrOCRConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
self.self_attn = TrOCRAttention(
|
||||
config,
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
)
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
if config.is_decoder:
|
||||
self.encoder_attn = TrOCRAttention(
|
||||
config,
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
kdim=config.cross_attention_hidden_size,
|
||||
vdim=config.cross_attention_hidden_size,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
||||
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
||||
:obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape :obj:`(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
||||
:obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||||
:obj:`(encoder_attention_heads,)`.
|
||||
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
||||
size `(decoder_attention_heads,)`.
|
||||
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
# Self Attention
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||
|
||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TrOCRPreTrainedModel(PreTrainedModel):
|
||||
config_class = TrOCRConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, TrOCRDecoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
TROCR_START_DOCSTRING = r"""
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.TrOCRConfig`):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
:meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
class TrOCRDecoder(TrOCRPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TrOCRDecoderLayer`
|
||||
|
||||
Args:
|
||||
config: TrOCRConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: TrOCRConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
|
||||
if config.use_learned_position_embeddings:
|
||||
self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
|
||||
else:
|
||||
self.embed_positions = TrOCRSinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings + self.padding_idx + 1,
|
||||
config.hidden_size,
|
||||
self.padding_idx,
|
||||
)
|
||||
|
||||
if config.layernorm_embedding:
|
||||
self.layernorm_embedding = nn.LayerNorm(config.hidden_size)
|
||||
else:
|
||||
self.layernorm_embedding = None
|
||||
|
||||
self.layers = nn.ModuleList([TrOCRDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
|
||||
self.init_weights()
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
||||
).to(self.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.TrOCRTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||
for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
of the decoder.
|
||||
encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
|
||||
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
||||
selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||
on hidden heads. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
||||
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||
cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential
|
||||
decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last
|
||||
:obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
|
||||
shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size,
|
||||
sequence_length)`.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
into associated vectors than the model's internal embedding lookup matrix.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||
for more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
if self.config.use_learned_position_embeddings:
|
||||
embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
|
||||
if self.layernorm_embedding is not None:
|
||||
hidden_states = self.layernorm_embedding(hidden_states)
|
||||
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||
if attn_mask is not None:
|
||||
if attn_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||
)
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, use_cache)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
head_mask[idx] if head_mask is not None else None,
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
cross_attn_layer_head_mask=(
|
||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The TrOCR Model with a language modeling head. Can be used for summarization.",
|
||||
TROCR_START_DOCSTRING,
|
||||
)
|
||||
class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
|
||||
"""
|
||||
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
|
||||
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.decoder = TrOCRDecoder(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.decoder(*args, **kwargs)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The TrOCR Decoder with a language modeling head. Can be used as the decoder part of :class:`~transformers.EncoderDecoderModel` and :class:`~transformers.VisionEncoderDecoder`.",
|
||||
TROCR_START_DOCSTRING,
|
||||
)
|
||||
class TrOCRForCausalLM(TrOCRPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
config.is_encoder_decoder = False
|
||||
self.model = TrOCRDecoderWrapper(config)
|
||||
|
||||
self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.decoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.decoder.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.output_projection
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.output_projection = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.decoder = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.TrOCRTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||
for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
if the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
|
||||
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2
|
||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
|
||||
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two
|
||||
additional tensors are only required when the model is used as a decoder in a Sequence to Sequence
|
||||
model.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||
cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential
|
||||
decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
|
||||
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
|
||||
config.vocab_size]``.
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
|
||||
for more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import VisionEncoderDecoderModel, TrOCRForCausalLM, ViTModel, TrOCRConfig, ViTConfig
|
||||
|
||||
>>> encoder = ViTModel(ViTConfig())
|
||||
>>> decoder = TrOCRForCausalLM(TrOCRConfig())
|
||||
|
||||
# init vision2text model
|
||||
>>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model.decoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits = self.output_projection(outputs[0])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
141
src/transformers/models/trocr/processing_trocr.py
Normal file
141
src/transformers/models/trocr/processing_trocr.py
Normal file
@ -0,0 +1,141 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for TrOCR.
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
from transformers.models.roberta.tokenization_roberta import RobertaTokenizer
|
||||
from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
|
||||
from ..auto.feature_extraction_auto import AutoFeatureExtractor
|
||||
|
||||
|
||||
class TrOCRProcessor:
|
||||
r"""
|
||||
Constructs a TrOCR processor which wraps a vision feature extractor and a TrOCR tokenizer into a single processor.
|
||||
|
||||
:class:`~transformers.TrOCRProcessor` offers all the functionalities of :class:`~transformers.AutoFeatureExtractor`
|
||||
and :class:`~transformers.RobertaTokenizer`. See the :meth:`~transformers.TrOCRProcessor.__call__` and
|
||||
:meth:`~transformers.TrOCRProcessor.decode` for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor (:class:`~transformers.AutoFeatureExtractor`):
|
||||
An instance of :class:`~transformers.AutoFeatureExtractor`. The feature extractor is a required input.
|
||||
tokenizer (:class:`~transformers.RobertaTokenizer`):
|
||||
An instance of :class:`~transformers.RobertaTokenizer`. The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
if not isinstance(feature_extractor, FeatureExtractionMixin):
|
||||
raise ValueError(
|
||||
f"`feature_extractor` has to be of type {FeatureExtractionMixin.__class__}, but is {type(feature_extractor)}"
|
||||
)
|
||||
if not (isinstance(tokenizer, RobertaTokenizer) or (isinstance(tokenizer, RobertaTokenizerFast))):
|
||||
raise ValueError(
|
||||
f"`tokenizer` has to be of type {RobertaTokenizer.__class__} or {RobertaTokenizerFast.__class__}, but is {type(tokenizer)}"
|
||||
)
|
||||
|
||||
self.feature_extractor = feature_extractor
|
||||
self.tokenizer = tokenizer
|
||||
self.current_processor = self.feature_extractor
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Save a TrOCR feature extractor object and TrOCR tokenizer object to the directory ``save_directory``, so that
|
||||
it can be re-loaded using the :func:`~transformers.TrOCRProcessor.from_pretrained` class method.
|
||||
|
||||
.. note::
|
||||
|
||||
This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
|
||||
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
|
||||
docstrings of the methods above for more information.
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.TrOCRProcessor` from a pretrained TrOCR processor.
|
||||
|
||||
.. note::
|
||||
|
||||
This class method is simply calling AutoFeatureExtractor's
|
||||
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and TrOCRTokenizer's
|
||||
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
|
||||
docstrings of the methods above for more information.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
|
||||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing a feature extractor file saved using the
|
||||
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
|
||||
``./my_model_directory/``.
|
||||
- a path or url to a saved feature extractor JSON `file`, e.g.,
|
||||
``./my_model_directory/preprocessor_config.json``.
|
||||
**kwargs
|
||||
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
|
||||
:class:`~transformers.PreTrainedTokenizer`
|
||||
"""
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
When used in normal mode, this method forwards all its arguments to AutoFeatureExtractor's
|
||||
:meth:`~transformers.AutoFeatureExtractor.__call__` and returns its output. If used in the context
|
||||
:meth:`~transformers.TrOCRProcessor.as_target_processor` this method forwards all its arguments to
|
||||
TrOCRTokenizer's :meth:`~transformers.TrOCRTokenizer.__call__`. Please refer to the doctsring of the above two
|
||||
methods for more information.
|
||||
"""
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to TrOCRTokenizer's
|
||||
:meth:`~transformers.PreTrainedTokenizer.batch_decode`. Please refer to the docstring of this method for more
|
||||
information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to TrOCRTokenizer's :meth:`~transformers.PreTrainedTokenizer.decode`.
|
||||
Please refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_processor(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR.
|
||||
"""
|
||||
self.current_processor = self.tokenizer
|
||||
yield
|
||||
self.current_processor = self.feature_extractor
|
40
src/transformers/models/vision_encoder_decoder/__init__.py
Normal file
40
src/transformers/models/vision_encoder_decoder/__init__.py
Normal file
@ -0,0 +1,40 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
|
||||
}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
@ -0,0 +1,120 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VisionEncoderDecoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
:class:`~transformers.VisionEncoderDecoderConfig` is the configuration class to store the configuration of a
|
||||
:class:`~transformers.VisionEncoderDecoderModel`. It is used to instantiate an Encoder Decoder model according to
|
||||
the specified arguments, defining the encoder and decoder configs.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
Args:
|
||||
kwargs (`optional`):
|
||||
Dictionary of keyword arguments. Notably:
|
||||
|
||||
- **encoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
|
||||
object that defines the encoder config.
|
||||
- **decoder** (:class:`~transformers.PretrainedConfig`, `optional`) -- An instance of a configuration
|
||||
object that defines the decoder config.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
|
||||
|
||||
>>> # Initializing a ViT & BERT style configuration
|
||||
>>> config_encoder = ViTConfig()
|
||||
>>> config_decoder = BertConfig()
|
||||
|
||||
>>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
|
||||
|
||||
>>> # Initializing a ViTBert model from a ViT & bert-base-uncased style configurations
|
||||
>>> model = VisionEncoderDecoderModel(config=config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> config_encoder = model.config.encoder
|
||||
>>> config_decoder = model.config.decoder
|
||||
>>> # set decoder config to causal lm
|
||||
>>> config_decoder.is_decoder = True
|
||||
>>> config_decoder.add_cross_attention = True
|
||||
|
||||
>>> # Saving the model, including its configuration
|
||||
>>> model.save_pretrained('my-model')
|
||||
|
||||
>>> # loading model and config from pretrained folder
|
||||
>>> encoder_decoder_config = VisionEncoderDecoderConfig.from_pretrained('my-model')
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
|
||||
"""
|
||||
model_type = "vision-encoder-decoder"
|
||||
is_composition = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if "encoder" not in kwargs or "decoder" not in kwargs:
|
||||
raise ValueError(
|
||||
f"A configuraton of type {self.model_type} cannot be instantiated because "
|
||||
f"not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
|
||||
)
|
||||
|
||||
encoder_config = kwargs.pop("encoder")
|
||||
encoder_model_type = encoder_config.pop("model_type")
|
||||
decoder_config = kwargs.pop("decoder")
|
||||
decoder_model_type = decoder_config.pop("model_type")
|
||||
|
||||
self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
|
||||
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
|
||||
self.is_encoder_decoder = True
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_configs(
|
||||
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
|
||||
) -> PretrainedConfig:
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.VisionEncoderDecoderConfig` (or a derived class) from a pre-trained encoder
|
||||
model configuration and decoder model configuration.
|
||||
|
||||
Returns:
|
||||
:class:`VisionEncoderDecoderConfig`: An instance of a configuration object
|
||||
"""
|
||||
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
|
||||
|
||||
Returns:
|
||||
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["encoder"] = self.encoder.to_dict()
|
||||
output["decoder"] = self.decoder.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
@ -0,0 +1,238 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert TrOCR checkpoints from the unilm repository."""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import requests
|
||||
from transformers import (
|
||||
RobertaTokenizer,
|
||||
TrOCRConfig,
|
||||
TrOCRForCausalLM,
|
||||
TrOCRProcessor,
|
||||
VisionEncoderDecoderModel,
|
||||
ViTConfig,
|
||||
ViTFeatureExtractor,
|
||||
ViTModel,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(encoder_config, decoder_config):
|
||||
rename_keys = []
|
||||
for i in range(encoder_config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.norm1.weight", f"encoder.encoder.layer.{i}.layernorm_before.weight")
|
||||
)
|
||||
rename_keys.append((f"encoder.deit.blocks.{i}.norm1.bias", f"encoder.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.attn.proj.weight", f"encoder.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.attn.proj.bias", f"encoder.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.norm2.weight", f"encoder.encoder.layer.{i}.layernorm_after.weight")
|
||||
)
|
||||
rename_keys.append((f"encoder.deit.blocks.{i}.norm2.bias", f"encoder.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.mlp.fc1.weight", f"encoder.encoder.layer.{i}.intermediate.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.mlp.fc1.bias", f"encoder.encoder.layer.{i}.intermediate.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"encoder.deit.blocks.{i}.mlp.fc2.weight", f"encoder.encoder.layer.{i}.output.dense.weight")
|
||||
)
|
||||
rename_keys.append((f"encoder.deit.blocks.{i}.mlp.fc2.bias", f"encoder.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# cls token, position embeddings and patch embeddings of encoder
|
||||
rename_keys.extend(
|
||||
[
|
||||
("encoder.deit.cls_token", "encoder.embeddings.cls_token"),
|
||||
("encoder.deit.pos_embed", "encoder.embeddings.position_embeddings"),
|
||||
("encoder.deit.patch_embed.proj.weight", "encoder.embeddings.patch_embeddings.projection.weight"),
|
||||
("encoder.deit.patch_embed.proj.bias", "encoder.embeddings.patch_embeddings.projection.bias"),
|
||||
("encoder.deit.norm.weight", "encoder.layernorm.weight"),
|
||||
("encoder.deit.norm.bias", "encoder.layernorm.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, encoder_config):
|
||||
for i in range(encoder_config.num_hidden_layers):
|
||||
# queries, keys and values (only weights, no biases)
|
||||
in_proj_weight = state_dict.pop(f"encoder.deit.blocks.{i}.attn.qkv.weight")
|
||||
|
||||
state_dict[f"encoder.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: encoder_config.hidden_size, :
|
||||
]
|
||||
state_dict[f"encoder.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
encoder_config.hidden_size : encoder_config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"encoder.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-encoder_config.hidden_size :, :
|
||||
]
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of the IAM Handwriting Database
|
||||
def prepare_img(checkpoint_url):
|
||||
if "handwritten" in checkpoint_url:
|
||||
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-00.jpg" # industry
|
||||
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-12.jpg" # have
|
||||
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02-10.jpg" # let
|
||||
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" #
|
||||
# url = "https://fki.tic.heia-fr.ch/static/img/a01-122.jpg"
|
||||
elif "printed" in checkpoint_url or "stage1" in checkpoint_url:
|
||||
url = "https://www.researchgate.net/profile/Dinh-Sang/publication/338099565/figure/fig8/AS:840413229350922@1577381536857/An-receipt-example-in-the-SROIE-2019-dataset_Q640.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_tr_ocr_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our VisionEncoderDecoderModel structure.
|
||||
"""
|
||||
# define encoder and decoder configs based on checkpoint_url
|
||||
encoder_config = ViTConfig(image_size=384, qkv_bias=False)
|
||||
decoder_config = TrOCRConfig()
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
decoder_config.encoder_hidden_size = 768
|
||||
elif "large" in checkpoint_url:
|
||||
# use ViT-large encoder
|
||||
encoder_config.hidden_size = 1024
|
||||
encoder_config.intermediate_size = 4096
|
||||
encoder_config.num_hidden_layers = 24
|
||||
encoder_config.num_attention_heads = 16
|
||||
decoder_config.encoder_hidden_size = 1024
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# the large-printed + stage1 checkpoints uses sinusoidal position embeddings, no layernorm afterwards
|
||||
if "large-printed" in checkpoint_url or "stage1" in checkpoint_url:
|
||||
decoder_config.tie_word_embeddings = False
|
||||
decoder_config.activation_function = "relu"
|
||||
decoder_config.max_position_embeddings = 1024
|
||||
decoder_config.scale_embedding = True
|
||||
decoder_config.use_learned_position_embeddings = False
|
||||
decoder_config.layernorm_embedding = False
|
||||
|
||||
# load HuggingFace model
|
||||
encoder = ViTModel(encoder_config, add_pooling_layer=False)
|
||||
decoder = TrOCRForCausalLM(decoder_config)
|
||||
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
model.eval()
|
||||
|
||||
# load state_dict of original model, rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
|
||||
|
||||
rename_keys = create_rename_keys(encoder_config, decoder_config)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, encoder_config)
|
||||
|
||||
# remove parameters we don't need
|
||||
del state_dict["encoder.deit.head.weight"]
|
||||
del state_dict["encoder.deit.head.bias"]
|
||||
del state_dict["decoder.version"]
|
||||
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("decoder") and "output_projection" not in key:
|
||||
state_dict["decoder.model." + key] = val
|
||||
else:
|
||||
state_dict[key] = val
|
||||
|
||||
# load state dict
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
feature_extractor = ViTFeatureExtractor(size=encoder_config.image_size)
|
||||
tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
|
||||
processor = TrOCRProcessor(feature_extractor, tokenizer)
|
||||
|
||||
pixel_values = processor(images=prepare_img(checkpoint_url), return_tensors="pt").pixel_values
|
||||
|
||||
# verify logits
|
||||
decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
|
||||
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
logits = outputs.logits
|
||||
|
||||
expected_shape = torch.Size([1, 1, 50265])
|
||||
if "trocr-base-handwritten" in checkpoint_url:
|
||||
expected_slice = torch.tensor(
|
||||
[-1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764, 1.7560, 8.7358, -1.5311]
|
||||
)
|
||||
elif "trocr-large-handwritten" in checkpoint_url:
|
||||
expected_slice = torch.tensor(
|
||||
[-2.6437, -1.3129, -2.2596, -5.3455, 6.3539, 1.7604, 5.4991, 1.4702, 5.6113, 2.0170]
|
||||
)
|
||||
elif "trocr-base-printed" in checkpoint_url:
|
||||
expected_slice = torch.tensor(
|
||||
[-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210]
|
||||
)
|
||||
elif "trocr-large-printed" in checkpoint_url:
|
||||
expected_slice = torch.tensor(
|
||||
[-6.0162, -7.0959, 4.4155, -5.1063, 7.0468, -3.1631, 2.6466, -0.3081, -0.8106, -1.7535]
|
||||
)
|
||||
|
||||
if "stage1" not in checkpoint_url:
|
||||
assert logits.shape == expected_shape, "Shape of logits not as expected"
|
||||
assert torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-3), "First elements of logits not as expected"
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving processor to {pytorch_dump_folder_path}")
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://layoutlm.blob.core.windows.net/trocr/model_zoo/fairseq/trocr-base-handwritten.pt",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tr_ocr_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
@ -0,0 +1,505 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Classes to support Vision-Encoder-Text-Decoder architectures """
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_outputs import Seq2SeqLMOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "VisionEncoderDecoderConfig"
|
||||
|
||||
VISION_ENCODER_DECODER_START_DOCSTRING = r"""
|
||||
This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model
|
||||
as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via
|
||||
:meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via
|
||||
:meth:`~transformers.AutoModelForCausalLM.from_pretrained` function. Cross-attention layers are automatically added
|
||||
to the decoder and should be fine-tuned on a downstream generative task, like image captioning.
|
||||
|
||||
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
|
||||
tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
|
||||
<https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
|
||||
Zhou, Wei Li, Peter J. Liu.
|
||||
|
||||
Additionally, in `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
|
||||
<https://arxiv.org/abs/2109.10282>`__ it is shown how leveraging large pretrained vision models for optical
|
||||
character recognition (OCR) yields a significant performance improvement.
|
||||
|
||||
After such an Vision-Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
|
||||
models (see the examples for more information).
|
||||
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.VisionEncoderDecoderConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||
weights.
|
||||
"""
|
||||
|
||||
VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder,
|
||||
you should use :class:`~transformers.ViTFeatureExtractor`). See
|
||||
:meth:`transformers.ViTFeatureExtractor.__call__` for details.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
||||
:obj:`past_key_values`).
|
||||
|
||||
Provide for sequence to sequence training to the decoder. Indices can be obtained using
|
||||
:class:`~transformers.PreTrainedTokenizer`. See :meth:`transformers.PreTrainedTokenizer.encode` and
|
||||
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||
also be used by default.
|
||||
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`):
|
||||
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
|
||||
:obj:`attentions`) :obj:`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size,
|
||||
sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the
|
||||
encoder. Used in the cross-attention of the decoder.
|
||||
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids`
|
||||
indices into associated vectors than the model's internal embedding lookup matrix.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the masked language modeling loss for the decoder. Indices should be in ``[-100, 0,
|
||||
..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
use_cache (:obj:`bool`, `optional`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`).
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a
|
||||
plain tuple.
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||
|
||||
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
|
||||
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
|
||||
class VisionEncoderDecoderModel(PreTrainedModel):
|
||||
r"""
|
||||
:class:`~transformers.VisionEncoderDecoderModel` is a generic model class that will be instantiated as a
|
||||
transformer architecture with one of the base model classes of the library as encoder and another one as decoder
|
||||
when created with the :meth`~transformers.AutoModel.from_pretrained` class method for the encoder and
|
||||
:meth`~transformers.AutoModelForCausalLM.from_pretrained` class method for the decoder.
|
||||
"""
|
||||
config_class = VisionEncoderDecoderConfig
|
||||
base_model_prefix = "vision_encoder_decoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PretrainedConfig] = None,
|
||||
encoder: Optional[PreTrainedModel] = None,
|
||||
decoder: Optional[PreTrainedModel] = None,
|
||||
):
|
||||
if config is None and (encoder is None or decoder is None):
|
||||
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
|
||||
if config is None:
|
||||
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
|
||||
else:
|
||||
if not isinstance(config, self.config_class):
|
||||
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
|
||||
|
||||
# initialize with config
|
||||
# make sure input & output embeddings is not tied
|
||||
config.tie_word_embeddings = False
|
||||
super().__init__(config)
|
||||
|
||||
if encoder is None:
|
||||
encoder = AutoModel.from_config(config.encoder)
|
||||
|
||||
if decoder is None:
|
||||
decoder = AutoModelForCausalLM.from_config(config.decoder)
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
|
||||
logger.warning(
|
||||
f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}"
|
||||
)
|
||||
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
|
||||
logger.warning(
|
||||
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}"
|
||||
)
|
||||
|
||||
# make sure that the individual model's config refers to the shared config
|
||||
# so that the updates to the config will be synced
|
||||
self.encoder.config = self.config.encoder
|
||||
self.decoder.config = self.config.decoder
|
||||
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
||||
|
||||
if self.encoder.get_output_embeddings() is not None:
|
||||
raise ValueError(
|
||||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.decoder.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.decoder.set_output_embeddings(new_embeddings)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
# At the moment fast initialization is not supported for composite models
|
||||
if kwargs.get("_fast_init", False):
|
||||
logger.warning(
|
||||
"Fast initialization is currently not supported for VisionEncoderDecoderModel. Falling back to slow intialization..."
|
||||
)
|
||||
kwargs["_fast_init"] = False
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_pretrained(
|
||||
cls,
|
||||
encoder_pretrained_model_name_or_path: str = None,
|
||||
decoder_pretrained_model_name_or_path: str = None,
|
||||
*model_args,
|
||||
**kwargs
|
||||
) -> PreTrainedModel:
|
||||
r"""
|
||||
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
|
||||
checkpoints.
|
||||
|
||||
|
||||
The model is set in evaluation mode by default using :obj:`model.eval()` (Dropout modules are deactivated). To
|
||||
train the model, you need to first set it back in training mode with :obj:`model.train()`.
|
||||
|
||||
Params:
|
||||
encoder_pretrained_model_name_or_path (:obj: `str`, `optional`):
|
||||
Information necessary to initiate the image encoder. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. An
|
||||
example is ``google/vit-base-patch16-224-in21k``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
|
||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
|
||||
decoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||
Information necessary to initiate the text decoder. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In
|
||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||
|
||||
model_args (remaining positional arguments, `optional`):
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`).
|
||||
|
||||
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter.
|
||||
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter.
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||
|
||||
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import VisionEncoderDecoderModel
|
||||
>>> # initialize a vit-bert from a pretrained ViT and a pretrained BERT model. Note that the cross-attention layers will be randomly initialized
|
||||
>>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('google/vit-base-patch16-224-in21k', 'bert-base-uncased')
|
||||
>>> # saving model after fine-tuning
|
||||
>>> model.save_pretrained("./vit-bert")
|
||||
>>> # load fine-tuned model
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")
|
||||
|
||||
"""
|
||||
|
||||
kwargs_encoder = {
|
||||
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
|
||||
}
|
||||
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# remove encoder, decoder kwargs from kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
del kwargs["encoder_" + key]
|
||||
for key in kwargs_decoder.keys():
|
||||
del kwargs["decoder_" + key]
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
if encoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. "
|
||||
f"In this case make sure that `encoder_pretrained_model_name_or_path` defined"
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
|
||||
"from a decoder model. Cross-attention and casual mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
|
||||
kwargs_encoder["config"] = encoder_config
|
||||
|
||||
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
if decoder_pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model."
|
||||
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
|
||||
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
|
||||
)
|
||||
decoder_config.is_decoder = True
|
||||
decoder_config.add_cross_attention = True
|
||||
|
||||
kwargs_decoder["config"] = decoder_config
|
||||
|
||||
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
|
||||
logger.warning(
|
||||
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder."
|
||||
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
|
||||
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config`"
|
||||
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` "
|
||||
f"to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
|
||||
# instantiate config with corresponding kwargs
|
||||
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||
|
||||
# make sure input & output embeddings is not tied
|
||||
config.tie_word_embeddings = False
|
||||
return cls(encoder=encoder, decoder=decoder, config=config)
|
||||
|
||||
@add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None,
|
||||
decoder_inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
>>> import torch
|
||||
|
||||
>>> processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
|
||||
>>> model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
|
||||
|
||||
>>> # load image from the IAM dataset
|
||||
>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
>>> # training
|
||||
>>> pixel_values = processor(image, return_tensors="pt").pixel_values # Batch size 1
|
||||
>>> decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
|
||||
>>> outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> # inference (generation)
|
||||
>>> generated_ids = model.generate(pixel_values)
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
||||
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
if encoder_outputs is None:
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
**kwargs_encoder,
|
||||
)
|
||||
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# optionally project encoder_hidden_states
|
||||
if (
|
||||
self.encoder.config.hidden_size != self.decoder.config.hidden_size
|
||||
and self.decoder.config.cross_attention_hidden_size is None
|
||||
):
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
# else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
inputs_embeds=decoder_inputs_embeds,
|
||||
labels=labels,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
use_cache=use_cache,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=return_dict,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=decoder_outputs.loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_hidden_states,
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||
):
|
||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||
input_dict = {
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": decoder_inputs["past_key_values"],
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
return input_dict
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported."
|
||||
"Please use the respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
|
||||
)
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
# apply decoder cache reordering here
|
||||
return self.decoder._reorder_cache(past, beam_idx)
|
@ -63,6 +63,8 @@ class ViTConfig(PretrainedConfig):
|
||||
The size (resolution) of each patch.
|
||||
num_channels (:obj:`int`, `optional`, defaults to :obj:`3`):
|
||||
The number of input channels.
|
||||
qkv_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
|
||||
|
||||
Example::
|
||||
@ -95,6 +97,7 @@ class ViTConfig(PretrainedConfig):
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
num_channels=3,
|
||||
qkv_bias=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -112,3 +115,4 @@ class ViTConfig(PretrainedConfig):
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.qkv_bias = qkv_bias
|
||||
|
@ -139,16 +139,19 @@ class FlaxViTSelfAttention(nn.Module):
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
use_bias=self.config.qkv_bias,
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
use_bias=self.config.qkv_bias,
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
use_bias=self.config.qkv_bias,
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
|
||||
|
@ -169,9 +169,9 @@ class ViTSelfAttention(nn.Module):
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
@ -505,6 +505,7 @@ class ViTModel(ViTPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
|
@ -3587,6 +3587,36 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
|
||||
requires_backends(load_tf_weights_in_transfo_xl, ["torch"])
|
||||
|
||||
|
||||
TROCR_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TrOCRForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class TrOCRPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VisionEncoderDecoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
191
tests/test_modeling_trocr.py
Normal file
191
tests/test_modeling_trocr.py
Normal file
@ -0,0 +1,191 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch TrOCR model. """
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import TrOCRConfig
|
||||
from transformers.testing_utils import is_torch_available, require_torch, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.trocr.modeling_trocr import TrOCRDecoder, TrOCRForCausalLM
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrOCRStandaloneDecoderModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
vocab_size=99,
|
||||
batch_size=13,
|
||||
d_model=16,
|
||||
decoder_seq_length=7,
|
||||
is_training=True,
|
||||
is_decoder=True,
|
||||
use_attention_mask=True,
|
||||
use_cache=False,
|
||||
use_labels=True,
|
||||
decoder_start_token_id=2,
|
||||
decoder_ffn_dim=32,
|
||||
decoder_layers=4,
|
||||
decoder_attention_heads=4,
|
||||
max_position_embeddings=30,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.decoder_seq_length = decoder_seq_length
|
||||
# For common tests
|
||||
self.seq_length = self.decoder_seq_length
|
||||
self.is_training = is_training
|
||||
self.use_attention_mask = use_attention_mask
|
||||
self.use_labels = use_labels
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.hidden_size = d_model
|
||||
self.num_hidden_layers = decoder_layers
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.num_attention_heads = decoder_attention_heads
|
||||
self.eos_token_id = eos_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.use_cache = use_cache
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.scope = None
|
||||
self.decoder_key_length = decoder_seq_length
|
||||
self.base_model_out_len = 2
|
||||
self.decoder_attention_idx = 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
|
||||
attention_mask = None
|
||||
if self.use_attention_mask:
|
||||
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
|
||||
|
||||
lm_labels = None
|
||||
if self.use_labels:
|
||||
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
|
||||
config = TrOCRConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.d_model,
|
||||
decoder_layers=self.decoder_layers,
|
||||
decoder_ffn_dim=self.decoder_ffn_dim,
|
||||
decoder_attention_heads=self.decoder_attention_heads,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
use_cache=self.use_cache,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
)
|
||||
|
||||
return (config, input_ids, attention_mask, lm_labels)
|
||||
|
||||
def create_and_check_decoder_model_past(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
lm_labels,
|
||||
):
|
||||
config.use_cache = True
|
||||
model = TrOCRDecoder(config=config).to(torch_device).eval()
|
||||
input_ids = input_ids[:2]
|
||||
|
||||
input_ids[input_ids == 0] += 1
|
||||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
outputs_use_cache_conf = model(input_ids)
|
||||
outputs_no_past = model(input_ids, use_cache=False)
|
||||
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
past_key_values = outputs["past_key_values"]
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((2, 1), config.vocab_size - 1) + 1
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
||||
|
||||
# test that outputs are equal for slice
|
||||
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, lm_labels = config_and_inputs
|
||||
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrOCRStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TrOCRDecoder, TrOCRForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (TrOCRForCausalLM,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TrOCRStandaloneDecoderModelTester(self, is_training=False)
|
||||
self.config_tester = ConfigTester(self, config_class=TrOCRConfig)
|
||||
|
||||
# not implemented currently
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# trocr has no base model
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
# trocr has no base model
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_decoder_model_past(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
|
||||
|
||||
# decoder cannot keep gradients
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
658
tests/test_modeling_vision_encoder_decoder.py
Normal file
658
tests/test_modeling_vision_encoder_decoder.py
Normal file
@ -0,0 +1,658 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from .test_modeling_deit import DeiTModelTester
|
||||
from .test_modeling_trocr import TrOCRStandaloneDecoderModelTester
|
||||
from .test_modeling_vit import ViTModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
BertLMHeadModel,
|
||||
DeiTModel,
|
||||
TrOCRForCausalLM,
|
||||
VisionEncoderDecoderConfig,
|
||||
VisionEncoderDecoderModel,
|
||||
ViTModel,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.models.vit.modeling_vit import to_2tuple
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TrOCRProcessor
|
||||
|
||||
|
||||
@require_torch
|
||||
class EncoderDecoderMixin:
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
pass
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pass
|
||||
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
pass
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained_configs(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder_decoder_config)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
return_dict,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
|
||||
enc_dec_model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
|
||||
def check_save_and_load(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
enc_dec_model.save_pretrained(tmpdirname)
|
||||
enc_dec_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname)
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def check_save_and_load_encoder_decoder_model(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
|
||||
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
|
||||
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
|
||||
VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
|
||||
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
|
||||
)
|
||||
|
||||
after_outputs = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels=None,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(encoder_model.config.image_size)
|
||||
patch_size = to_2tuple(encoder_model.config.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 1
|
||||
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
num_decoder_layers = (
|
||||
decoder_config.num_decoder_layers
|
||||
if hasattr(decoder_config, "num_decoder_layers")
|
||||
else decoder_config.num_hidden_layers
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertEqual(
|
||||
decoder_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
|
||||
self.assertEqual(
|
||||
cross_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
inputs = pixel_values
|
||||
|
||||
# Bert does not have a bos token id, so use pad_token_id instead
|
||||
generated_output = enc_dec_model.generate(
|
||||
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
|
||||
)
|
||||
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
|
||||
|
||||
def test_encoder_decoder_model(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_configs(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
|
||||
|
||||
def test_encoder_decoder_model_from_pretrained_return_dict(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
|
||||
|
||||
def test_save_and_load_from_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_and_load(**input_ids_dict)
|
||||
|
||||
def test_save_and_load_from_encoder_decoder_pretrained(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_and_load_encoder_decoder_model(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_output_attentions(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
|
||||
|
||||
def test_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
model_2.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model_2(**inputs)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model_2.save_pretrained(tmp_dirname)
|
||||
model_1 = VisionEncoderDecoderModel.from_pretrained(tmp_dirname)
|
||||
model_1.to(torch_device)
|
||||
|
||||
after_outputs = model_1(**inputs)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
|
||||
@require_torch
|
||||
class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"hf-internal-testing/tiny-random-deit", "hf-internal-testing/tiny-random-roberta"
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.encoder.config.num_channels,
|
||||
model.encoder.config.image_size,
|
||||
model.encoder.config.image_size,
|
||||
]
|
||||
)
|
||||
# for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
|
||||
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
|
||||
attention_mask = random_attention_mask([batch_size, seq_len])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def check_encoder_decoder_model_output_attentions(
|
||||
self,
|
||||
config,
|
||||
attention_mask,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
labels=None,
|
||||
pixel_values=None,
|
||||
**kwargs
|
||||
):
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
|
||||
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
|
||||
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
|
||||
|
||||
# in DEiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
|
||||
image_size = to_2tuple(encoder_model.config.image_size)
|
||||
patch_size = to_2tuple(encoder_model.config.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 2
|
||||
self.assertEqual(encoder_attentions[0].shape[-3:], (config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
|
||||
num_decoder_layers = (
|
||||
decoder_config.num_decoder_layers
|
||||
if hasattr(decoder_config, "num_decoder_layers")
|
||||
else decoder_config.num_hidden_layers
|
||||
)
|
||||
self.assertEqual(len(decoder_attentions), num_decoder_layers)
|
||||
|
||||
self.assertEqual(
|
||||
decoder_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
cross_attentions = outputs_encoder_decoder["cross_attentions"]
|
||||
self.assertEqual(len(cross_attentions), num_decoder_layers)
|
||||
|
||||
cross_attention_input_seq_len = decoder_input_ids.shape[-1]
|
||||
self.assertEqual(
|
||||
cross_attentions[0].shape[-3:],
|
||||
(decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len),
|
||||
)
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = DeiTModel(config).eval()
|
||||
decoder_model = BertLMHeadModel(decoder_config).eval()
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
bert_model_tester = BertModelTester(self)
|
||||
deit_model_tester = DeiTModelTester(self)
|
||||
encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_token_type_ids,
|
||||
decoder_input_mask,
|
||||
decoder_sequence_labels,
|
||||
decoder_token_labels,
|
||||
decoder_choice_labels,
|
||||
encoder_attention_mask,
|
||||
_,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
"decoder_attention_mask": decoder_input_mask,
|
||||
"decoder_sequence_labels": decoder_sequence_labels,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert"
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.encoder.config.num_channels,
|
||||
model.encoder.config.image_size,
|
||||
model.encoder.config.image_size,
|
||||
]
|
||||
)
|
||||
# for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
|
||||
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
|
||||
attention_mask = random_attention_mask([batch_size, seq_len])
|
||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = ViTModel(config).eval()
|
||||
decoder_model = BertLMHeadModel(decoder_config).eval()
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
vit_model_tester = ViTModelTester(self)
|
||||
bert_model_tester = BertModelTester(self)
|
||||
encoder_config_and_inputs = vit_model_tester.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
||||
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_token_type_ids,
|
||||
decoder_input_mask,
|
||||
decoder_sequence_labels,
|
||||
decoder_token_labels,
|
||||
decoder_choice_labels,
|
||||
encoder_attention_mask,
|
||||
_,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
"decoder_attention_mask": decoder_input_mask,
|
||||
"decoder_sequence_labels": decoder_sequence_labels,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = ViTModel(config).eval()
|
||||
decoder_model = TrOCRForCausalLM(decoder_config).eval()
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = ViTModelTester(self, batch_size=13)
|
||||
model_tester_decoder = TrOCRStandaloneDecoderModelTester(
|
||||
self, batch_size=13, d_model=32, max_position_embeddings=512
|
||||
)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
||||
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
# disable cache for now
|
||||
decoder_config.use_cache = False
|
||||
return {
|
||||
"config": config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
|
||||
# there are no published pretrained TrOCR checkpoints for now
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_handwritten(self):
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(torch_device)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
|
||||
image = Image.open(ds[0]["file"]).convert("RGB")
|
||||
|
||||
processor = self.default_processor
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]]).to(torch_device)
|
||||
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[-1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764, 1.7560, 8.7358, -1.5311]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_inference_printed(self):
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed").to(torch_device)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
|
||||
image = Image.open(ds[1]["file"]).convert("RGB")
|
||||
|
||||
processor = self.default_processor
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
# forward pass
|
||||
decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]]).to(torch_device)
|
||||
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[-5.6816, -5.8388, 1.1398, -6.9034, 6.8505, -2.4393, 1.2284, -1.0232, -1.9661, -3.9210]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
Loading…
Reference in New Issue
Block a user