Add LayoutLMv3 (#17060)
* Make forward pass work
* More improvements
* Remove unused imports
* Remove timm dependency
* Improve loss calculation of token classifier
* Fix most tests
* Add docs
* Add model integration test
* Make all tests pass
* Add LayoutLMv3FeatureExtractor
* Improve integration test + make fixup
* Add example script
* Fix style
* Add LayoutLMv3Processor
* Fix style
* Add option to add visual labels
* Make more tokenizer tests pass
* Fix more tests
* Make more tests pass
* Fix bug and improve docs
* Fix import of processors
* Improve docstrings
* Fix toctree and improve docs
* Fix auto tokenizer
* Move tests to model folder
* Move tests to model folder
* change default behavior of add_prefix_space
* add prefix space for fast
* add_prefix_space set to True for Fast
* no space before `unique_no_split` token
* add test to highlight special treatment of added tokens
* fix `test_batch_encode_dynamic_overflowing` by building a long enough example
* fix `test_full_tokenizer` with add_prefix_token
* Fix tokenizer integration test
* Make the code more readable
* Add tests for LayoutLMv3Processor
* Fix style
* Add model to README and update init
* Apply suggestions from code review
* Replace asserts by value errors
* Add suggestion by @ducviet00
* Add model to doc tests
* Simplify script
* Improve README
* a step ahead to fix
* Update pair_input_test
* Make all tokenizer tests pass - phew
* Make style
* Add LayoutLMv3 to CI job
* Fix auto mapping
* Fix CI job name
* Make all processor tests pass
* Make tests of LayoutLMv2 and LayoutXLM consistent
* Add copied from statements to fast tokenizer
* Add copied from statements to slow tokenizer
* Remove add_visual_labels attribute
* Fix tests
* Add link to notebooks
* Improve docs of LayoutLMv3Processor
* Fix reference to section

Co-authored-by: SaulLu <lucilesaul.com@gmail.com>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
13541b4aa2
commit
31ee80d556
@@ -890,7 +890,7 @@ jobs:
 - run: make deps_table_check_updated
 - run: python utils/tests_fetcher.py --sanity_check

-run_tests_layoutlmv2:
+run_tests_layoutlmv2_and_v3:
 working_directory: ~/transformers
 docker:
 - image: circleci/python:3.7
@@ -921,7 +921,7 @@ jobs:
 path: ~/transformers/test_preparation.txt
 - run: |
 if [ -f test_list.txt ]; then
-python -m pytest -n 1 tests/models/*layoutlmv2* --dist=loadfile -s --make-reports=tests_layoutlmv2 --durations=100
+python -m pytest -n 1 tests/models/*layoutlmv* --dist=loadfile -s --make-reports=tests_layoutlmv2_and_v3 --durations=100
 fi
 - store_artifacts:
 path: ~/transformers/tests_output.txt
@@ -982,7 +982,7 @@ workflows:
 - run_tests_pipelines_tf
 - run_tests_onnxruntime
 - run_tests_hub
-- run_tests_layoutlmv2
+- run_tests_layoutlmv2_and_v3
 nightly:
 triggers:
 - schedule:
@@ -279,6 +279,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou.
+1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei.
 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
@@ -258,6 +258,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou.
+1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei.
 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
@@ -282,6 +282,7 @@ conda install -c huggingface transformers
 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (来自 OpenAI) 伴随论文 [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) 由 Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever 发布。
 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) 由 Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou 发布。
 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) 由 Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou 发布。
+1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (来自 Microsoft Research Asia) 伴随论文 [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) 由 Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei 发布。
 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (来自 Microsoft Research Asia) 伴随论文 [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) 由 Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei 发布。
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。
@@ -294,6 +294,7 @@ conda install -c huggingface transformers
 1. **[ImageGPT](https://huggingface.co/docs/transformers/main/model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
 1. **[LayoutLM](https://huggingface.co/docs/transformers/model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
 1. **[LayoutLMv2](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou.
+1. **[LayoutLMv3](https://huggingface.co/docs/transformers/main/model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei.
 1. **[LayoutXLM](https://huggingface.co/docs/transformers/model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
@@ -242,6 +242,8 @@
 title: LayoutLM
 - local: model_doc/layoutlmv2
 title: LayoutLMV2
+- local: model_doc/layoutlmv3
+title: LayoutLMV3
 - local: model_doc/layoutxlm
 title: LayoutXLM
 - local: model_doc/led
@@ -100,6 +100,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
 1. **[ImageGPT](model_doc/imagegpt)** (from OpenAI) released with the paper [Generative Pretraining from Pixels](https://openai.com/blog/image-gpt/) by Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Heewoo Jun, David Luan, Ilya Sutskever.
 1. **[LayoutLM](model_doc/layoutlm)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
 1. **[LayoutLMv2](model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou.
+1. **[LayoutLMv3](model_doc/layoutlmv3)** (from Microsoft Research Asia) released with the paper [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei.
 1. **[LayoutXLM](model_doc/layoutlmv2)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
 1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
@@ -220,6 +221,7 @@ Flax), PyTorch, and/or TensorFlow.
 | ImageGPT | ❌ | ❌ | ✅ | ❌ | ❌ |
 | LayoutLM | ✅ | ✅ | ✅ | ✅ | ❌ |
 | LayoutLMv2 | ✅ | ✅ | ✅ | ❌ | ❌ |
+| LayoutLMv3 | ✅ | ✅ | ✅ | ❌ | ❌ |
 | LED | ✅ | ✅ | ✅ | ✅ | ❌ |
 | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
 | LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
docs/source/en/model_doc/layoutlmv3.mdx (new file, 85 lines)

@@ -0,0 +1,85 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LayoutLMv3

## Overview

The LayoutLMv3 model was proposed in [LayoutLMv3: Pre-training for Document AI with Unified Text and Image Masking](https://arxiv.org/abs/2204.08387) by Yupan Huang, Tengchao Lv, Lei Cui, Yutong Lu, Furu Wei.
LayoutLMv3 simplifies [LayoutLMv2](layoutlmv2) by using patch embeddings (as in [ViT](vit)) instead of leveraging a CNN backbone, and pre-trains the model on 3 objectives: masked language modeling (MLM), masked image modeling (MIM)
and word-patch alignment (WPA).

The abstract from the paper is the following:

*Self-supervised pre-training techniques have achieved remarkable progress in Document AI. Most multimodal pre-trained models use a masked language modeling objective to learn bidirectional representations on the text modality, but they differ in pre-training objectives for the image modality. This discrepancy adds difficulty to multimodal representation learning. In this paper, we propose LayoutLMv3 to pre-train multimodal Transformers for Document AI with unified text and image masking. Additionally, LayoutLMv3 is pre-trained with a word-patch alignment objective to learn cross-modal alignment by predicting whether the corresponding image patch of a text word is masked. The simple unified architecture and training objectives make LayoutLMv3 a general-purpose pre-trained model for both text-centric and image-centric Document AI tasks. Experimental results show that LayoutLMv3 achieves state-of-the-art performance not only in text-centric tasks, including form understanding, receipt understanding, and document visual question answering, but also in image-centric tasks such as document image classification and document layout analysis.*

Tips:

- In terms of data processing, LayoutLMv3 is identical to its predecessor [LayoutLMv2](layoutlmv2), except that:
  - images need to be resized and normalized with channels in regular RGB format. LayoutLMv2 on the other hand normalizes the images internally and expects the channels in BGR format.
  - text is tokenized using byte-pair encoding (BPE), as opposed to WordPiece.
  Due to these differences in data preprocessing, one can use [`LayoutLMv3Processor`] which internally combines a [`LayoutLMv3FeatureExtractor`] (for the image modality) and a [`LayoutLMv3Tokenizer`]/[`LayoutLMv3TokenizerFast`] (for the text modality) to prepare all data for the model.
- Regarding usage of [`LayoutLMv3Processor`], we refer to the [usage guide](layoutlmv2#usage-LayoutLMv2Processor) of its predecessor.
- Demo notebooks for LayoutLMv3 can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/LayoutLMv3).
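As a quick illustration of this preprocessing flow, here is a minimal sketch (the checkpoint, the `num_labels` value and the input image are placeholders, and the default `apply_ocr=True` requires Tesseract so the feature extractor can obtain words and boxes itself):

```python
from PIL import Image
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

# Sketch only: checkpoint, label count and image path are illustrative, not prescribed by this doc.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)

image = Image.open("document.png").convert("RGB")  # any document image, channels in RGB

# With apply_ocr=True (the default), the feature extractor resizes/normalizes the image and runs OCR
# to obtain words and normalized boxes; the tokenizer then produces input_ids and bbox from them.
encoding = processor(image, return_tensors="pt")  # pixel_values, input_ids, attention_mask, bbox

outputs = model(**encoding)
logits = outputs.logits  # one prediction per text token
```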
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/layoutlmv3_architecture.png"
alt="drawing" width="600"/>

<small> LayoutLMv3 architecture. Taken from the <a href="https://arxiv.org/abs/2204.08387">original paper</a>. </small>

This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/microsoft/unilm/tree/master/layoutlmv3).


## LayoutLMv3Config

[[autodoc]] LayoutLMv3Config

## LayoutLMv3FeatureExtractor

[[autodoc]] LayoutLMv3FeatureExtractor
    - __call__

## LayoutLMv3Tokenizer

[[autodoc]] LayoutLMv3Tokenizer
    - __call__
    - save_vocabulary

## LayoutLMv3TokenizerFast

[[autodoc]] LayoutLMv3TokenizerFast
    - __call__

## LayoutLMv3Processor

[[autodoc]] LayoutLMv3Processor
    - __call__

## LayoutLMv3Model

[[autodoc]] LayoutLMv3Model
    - forward

## LayoutLMv3ForSequenceClassification

[[autodoc]] LayoutLMv3ForSequenceClassification
    - forward

## LayoutLMv3ForTokenClassification

[[autodoc]] LayoutLMv3ForTokenClassification
    - forward

## LayoutLMv3ForQuestionAnswering

[[autodoc]] LayoutLMv3ForQuestionAnswering
    - forward
examples/research_projects/layoutlmv3/README.md (new file, 69 lines)

@@ -0,0 +1,69 @@
<!---
Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Token classification with LayoutLMv3 (PyTorch version)

This directory contains a script, `run_funsd_cord.py`, that can be used to fine-tune (or evaluate) LayoutLMv3 on form understanding datasets, such as [FUNSD](https://guillaumejaume.github.io/FUNSD/) and [CORD](https://github.com/clovaai/cord).

The script `run_funsd_cord.py` leverages the 🤗 Datasets library and the Trainer API. You can easily customize it to your needs.
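For reference, the ready-to-use datasets the script loads by default expose an image, the words, their bounding boxes and word-level labels. A small sketch of how to inspect them (the column names below match the defaults used further down in the script, but may differ for your own data):

```python
from datasets import load_dataset

# Sketch: "nielsr/funsd-layoutlmv3" is the script's default dataset_name.
dataset = load_dataset("nielsr/funsd-layoutlmv3")

# The script looks for an "image" column, a "words"/"tokens" column, a "bboxes" column
# and a "{task_name}_tags" label column (e.g. "ner_tags").
print(dataset["train"].column_names)
print(dataset["train"].features["ner_tags"].feature.names)  # label names used for id2label/label2id
```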
## Fine-tuning on FUNSD

Fine-tuning LayoutLMv3 for token classification on [FUNSD](https://guillaumejaume.github.io/FUNSD/) can be done as follows:

```bash
python run_funsd_cord.py \
  --model_name_or_path microsoft/layoutlmv3-base \
  --dataset_name funsd \
  --output_dir layoutlmv3-test \
  --do_train \
  --do_eval \
  --max_steps 1000 \
  --evaluation_strategy steps \
  --eval_steps 100 \
  --learning_rate 1e-5 \
  --load_best_model_at_end \
  --metric_for_best_model "eval_f1" \
  --push_to_hub \
  --push_to_hub_model_id layoutlmv3-finetuned-funsd
```

👀 The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd. By specifying the `push_to_hub` flag, the model gets uploaded automatically to the hub (regularly), together with a model card, which includes metrics such as precision, recall and F1. Note that you can easily update the model card, as it's just a README file of the respective repo on the hub.

There's also the "Training metrics" [tab](https://huggingface.co/nielsr/layoutlmv3-finetuned-funsd/tensorboard), which shows Tensorboard logs over the course of training. Pretty neat, huh?

## Fine-tuning on CORD

Fine-tuning LayoutLMv3 for token classification on [CORD](https://github.com/clovaai/cord) can be done as follows:

```bash
python run_funsd_cord.py \
  --model_name_or_path microsoft/layoutlmv3-base \
  --dataset_name cord \
  --output_dir layoutlmv3-test \
  --do_train \
  --do_eval \
  --max_steps 1000 \
  --evaluation_strategy steps \
  --eval_steps 100 \
  --learning_rate 5e-5 \
  --load_best_model_at_end \
  --metric_for_best_model "eval_f1" \
  --push_to_hub \
  --push_to_hub_model_id layoutlmv3-finetuned-cord
```

👀 The resulting model can be found here: https://huggingface.co/nielsr/layoutlmv3-finetuned-cord. Note that a model card gets generated automatically in case you specify the `push_to_hub` flag.
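Afterwards, a fine-tuned checkpoint can be tried on a new document image, for example along these lines (a sketch: the image path is a placeholder, and letting the processor run OCR requires Tesseract to be installed):

```python
import torch
from PIL import Image
from transformers import AutoModelForTokenClassification, AutoProcessor

# Sketch only: re-enable OCR so the processor extracts words and boxes from the raw image.
processor = AutoProcessor.from_pretrained("nielsr/layoutlmv3-finetuned-funsd", apply_ocr=True)
model = AutoModelForTokenClassification.from_pretrained("nielsr/layoutlmv3-finetuned-funsd")

image = Image.open("document.png").convert("RGB")
encoding = processor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**encoding).logits

predictions = logits.argmax(-1).squeeze().tolist()
print([model.config.id2label[p] for p in predictions][:10])  # first few token-level labels
```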
examples/research_projects/layoutlmv3/requirements.txt (new file, 2 lines)

@@ -0,0 +1,2 @@
datasets
seqeval
examples/research_projects/layoutlmv3/run_funsd_cord.py (new file, 533 lines)
@@ -0,0 +1,533 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning LayoutLMv3 for token classification on FUNSD or CORD.
"""
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
# comments.

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import ClassLabel, load_dataset, load_metric

import transformers
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoProcessor,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.data.data_collator import default_data_collator
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.19.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
default="microsoft/layoutlmv3-base",
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
processor_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Name or path to the processor files if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
|
||||
dataset_name: Optional[str] = field(
|
||||
default="nielsr/funsd-layoutlmv3",
|
||||
metadata={"help": "The name of the dataset to use (via the datasets library)."},
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
|
||||
)
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
|
||||
)
|
||||
text_column_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
|
||||
)
|
||||
label_column_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. If set, sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
label_all_tokens: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether to put the label for one word on all tokens of generated by that word or just on the "
|
||||
"one (in which case the other tokens will have a padding index)."
|
||||
)
|
||||
},
|
||||
)
|
||||
return_entity_level_metrics: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
self.task_name = self.task_name.lower()
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Detecting last checkpoint.
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name == "funsd":
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
"nielsr/funsd-layoutlmv3",
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
elif data_args.dataset_name == "cord":
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
"nielsr/cord-layoutlmv3",
|
||||
data_args.dataset_config_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("This script only supports either FUNSD or CORD out-of-the-box.")
|
||||
|
||||
if training_args.do_train:
|
||||
column_names = dataset["train"].column_names
|
||||
features = dataset["train"].features
|
||||
else:
|
||||
column_names = dataset["test"].column_names
|
||||
features = dataset["test"].features
|
||||
|
||||
image_column_name = "image"
|
||||
text_column_name = "words" if "words" in column_names else "tokens"
|
||||
boxes_column_name = "bboxes"
|
||||
label_column_name = (
|
||||
f"{data_args.task_name}_tags" if f"{data_args.task_name}_tags" in column_names else column_names[1]
|
||||
)
|
||||
|
||||
remove_columns = column_names
|
||||
|
||||
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
|
||||
# unique labels.
|
||||
def get_label_list(labels):
|
||||
unique_labels = set()
|
||||
for label in labels:
|
||||
unique_labels = unique_labels | set(label)
|
||||
label_list = list(unique_labels)
|
||||
label_list.sort()
|
||||
return label_list
|
||||
|
||||
# If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
|
||||
# Otherwise, we have to get the list of labels manually.
|
||||
if isinstance(features[label_column_name].feature, ClassLabel):
|
||||
label_list = features[label_column_name].feature.names
|
||||
# No need to convert the labels since they are already ints.
|
||||
id2label = {k: v for k, v in enumerate(label_list)}
|
||||
label2id = {v: k for k, v in enumerate(label_list)}
|
||||
else:
|
||||
label_list = get_label_list(dataset["train"][label_column_name])
|
||||
id2label = {k: v for k, v in enumerate(label_list)}
|
||||
label2id = {v: k for k, v in enumerate(label_list)}
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and processor
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=data_args.task_name,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.processor_name if model_args.processor_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=True,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
add_prefix_space=True,
|
||||
apply_ocr=False,
|
||||
)
|
||||
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Set the correspondences label/ID inside the model config
|
||||
model.config.label2id = label2id
|
||||
model.config.id2label = id2label
|
||||
|
||||
# Preprocessing the dataset
|
||||
# The processor does everything for us (prepare the image using LayoutLMv3FeatureExtractor
|
||||
# and prepare the words, boxes and word-level labels using LayoutLMv3TokenizerFast)
|
||||
def prepare_examples(examples):
|
||||
images = examples[image_column_name]
|
||||
words = examples[text_column_name]
|
||||
boxes = examples[boxes_column_name]
|
||||
word_labels = examples[label_column_name]
|
||||
|
||||
encoding = processor(
|
||||
images,
|
||||
words,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=data_args.max_seq_length,
|
||||
)
|
||||
|
||||
return encoding
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in dataset:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = dataset["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
with training_args.main_process_first(desc="train dataset map pre-processing"):
|
||||
train_dataset = train_dataset.map(
|
||||
prepare_examples,
|
||||
batched=True,
|
||||
remove_columns=remove_columns,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
if training_args.do_eval:
|
||||
validation_name = "test"
|
||||
if validation_name not in dataset:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = dataset[validation_name]
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
with training_args.main_process_first(desc="validation dataset map pre-processing"):
|
||||
eval_dataset = eval_dataset.map(
|
||||
prepare_examples,
|
||||
batched=True,
|
||||
remove_columns=remove_columns,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
|
||||
predict_dataset = predict_dataset.map(
|
||||
prepare_examples,
|
||||
batched=True,
|
||||
remove_columns=remove_columns,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Metrics
|
||||
metric = load_metric("seqeval")
|
||||
|
||||
def compute_metrics(p):
|
||||
predictions, labels = p
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
|
||||
# Remove ignored index (special tokens)
|
||||
true_predictions = [
|
||||
[label_list[p] for (p, l) in zip(prediction, label) if l != -100]
|
||||
for prediction, label in zip(predictions, labels)
|
||||
]
|
||||
true_labels = [
|
||||
[label_list[l] for (p, l) in zip(prediction, label) if l != -100]
|
||||
for prediction, label in zip(predictions, labels)
|
||||
]
|
||||
|
||||
results = metric.compute(predictions=true_predictions, references=true_labels)
|
||||
if data_args.return_entity_level_metrics:
|
||||
# Unpack nested dictionaries
|
||||
final_results = {}
|
||||
for key, value in results.items():
|
||||
if isinstance(value, dict):
|
||||
for n, v in value.items():
|
||||
final_results[f"{key}_{n}"] = v
|
||||
else:
|
||||
final_results[key] = value
|
||||
return final_results
|
||||
else:
|
||||
return {
|
||||
"precision": results["overall_precision"],
|
||||
"recall": results["overall_recall"],
|
||||
"f1": results["overall_f1"],
|
||||
"accuracy": results["overall_accuracy"],
|
||||
}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
tokenizer=processor,
|
||||
data_collator=default_data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
|
||||
# Remove ignored index (special tokens)
|
||||
true_predictions = [
|
||||
[label_list[p] for (p, l) in zip(prediction, label) if l != -100]
|
||||
for prediction, label in zip(predictions, labels)
|
||||
]
|
||||
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
# Save predictions
|
||||
output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_predictions_file, "w") as writer:
|
||||
for prediction in true_predictions:
|
||||
writer.write(" ".join(prediction) + "\n")
|
||||
|
||||
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"}
|
||||
if data_args.dataset_name is not None:
|
||||
kwargs["dataset_tags"] = data_args.dataset_name
|
||||
if data_args.dataset_config_name is not None:
|
||||
kwargs["dataset_args"] = data_args.dataset_config_name
|
||||
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
||||
else:
|
||||
kwargs["dataset"] = data_args.dataset_name
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
else:
|
||||
trainer.create_model_card(**kwargs)
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -226,6 +226,13 @@ _import_structure = {
|
||||
"LayoutLMv2Processor",
|
||||
"LayoutLMv2Tokenizer",
|
||||
],
|
||||
"models.layoutlmv3": [
|
||||
"LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"LayoutLMv3Config",
|
||||
"LayoutLMv3FeatureExtractor",
|
||||
"LayoutLMv3Processor",
|
||||
"LayoutLMv3Tokenizer",
|
||||
],
|
||||
"models.layoutxlm": ["LayoutXLMProcessor"],
|
||||
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
|
||||
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
|
||||
@ -504,6 +511,7 @@ else:
|
||||
_import_structure["models.herbert"].append("HerbertTokenizerFast")
|
||||
_import_structure["models.layoutlm"].append("LayoutLMTokenizerFast")
|
||||
_import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast")
|
||||
_import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
|
||||
_import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
|
||||
_import_structure["models.led"].append("LEDTokenizerFast")
|
||||
_import_structure["models.longformer"].append("LongformerTokenizerFast")
|
||||
@ -590,8 +598,7 @@ else:
|
||||
_import_structure["models.glpn"].append("GLPNFeatureExtractor")
|
||||
_import_structure["models.imagegpt"].append("ImageGPTFeatureExtractor")
|
||||
_import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor")
|
||||
_import_structure["models.layoutlmv2"].append("LayoutLMv2Processor")
|
||||
_import_structure["models.layoutxlm"].append("LayoutXLMProcessor")
|
||||
_import_structure["models.layoutlmv3"].append("LayoutLMv3FeatureExtractor")
|
||||
_import_structure["models.maskformer"].append("MaskFormerFeatureExtractor")
|
||||
_import_structure["models.perceiver"].append("PerceiverFeatureExtractor")
|
||||
_import_structure["models.poolformer"].append("PoolFormerFeatureExtractor")
|
||||
@ -1199,6 +1206,16 @@ else:
|
||||
"LayoutLMv2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.layoutlmv3"].extend(
|
||||
[
|
||||
"LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"LayoutLMv3ForQuestionAnswering",
|
||||
"LayoutLMv3ForSequenceClassification",
|
||||
"LayoutLMv3ForTokenClassification",
|
||||
"LayoutLMv3Model",
|
||||
"LayoutLMv3PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.led"].extend(
|
||||
[
|
||||
"LED_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -2759,6 +2776,13 @@ if TYPE_CHECKING:
|
||||
LayoutLMv2Processor,
|
||||
LayoutLMv2Tokenizer,
|
||||
)
|
||||
from .models.layoutlmv3 import (
|
||||
LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
LayoutLMv3Config,
|
||||
LayoutLMv3FeatureExtractor,
|
||||
LayoutLMv3Processor,
|
||||
LayoutLMv3Tokenizer,
|
||||
)
|
||||
from .models.layoutxlm import LayoutXLMProcessor
|
||||
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
|
||||
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
|
||||
@ -3004,6 +3028,7 @@ if TYPE_CHECKING:
|
||||
from .models.herbert import HerbertTokenizerFast
|
||||
from .models.layoutlm import LayoutLMTokenizerFast
|
||||
from .models.layoutlmv2 import LayoutLMv2TokenizerFast
|
||||
from .models.layoutlmv3 import LayoutLMv3TokenizerFast
|
||||
from .models.layoutxlm import LayoutXLMTokenizerFast
|
||||
from .models.led import LEDTokenizerFast
|
||||
from .models.longformer import LongformerTokenizerFast
|
||||
@ -3069,8 +3094,8 @@ if TYPE_CHECKING:
|
||||
from .models.flava import FlavaFeatureExtractor, FlavaProcessor
|
||||
from .models.glpn import GLPNFeatureExtractor
|
||||
from .models.imagegpt import ImageGPTFeatureExtractor
|
||||
from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor
|
||||
from .models.layoutxlm import LayoutXLMProcessor
|
||||
from .models.layoutlmv2 import LayoutLMv2FeatureExtractor
|
||||
from .models.layoutlmv3 import LayoutLMv3FeatureExtractor
|
||||
from .models.maskformer import MaskFormerFeatureExtractor
|
||||
from .models.perceiver import PerceiverFeatureExtractor
|
||||
from .models.poolformer import PoolFormerFeatureExtractor
|
||||
@ -3581,6 +3606,14 @@ if TYPE_CHECKING:
|
||||
LayoutLMv2Model,
|
||||
LayoutLMv2PreTrainedModel,
|
||||
)
|
||||
from .models.layoutlmv3 import (
|
||||
LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LayoutLMv3ForQuestionAnswering,
|
||||
LayoutLMv3ForSequenceClassification,
|
||||
LayoutLMv3ForTokenClassification,
|
||||
LayoutLMv3Model,
|
||||
LayoutLMv3PreTrainedModel,
|
||||
)
|
||||
from .models.led import (
|
||||
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LEDForConditionalGeneration,
|
||||
|
@ -1023,6 +1023,7 @@ SLOW_TO_FAST_CONVERTERS = {
|
||||
"HerbertTokenizer": HerbertConverter,
|
||||
"LayoutLMTokenizer": BertConverter,
|
||||
"LayoutLMv2Tokenizer": BertConverter,
|
||||
"LayoutLMv3Tokenizer": RobertaConverter,
|
||||
"LayoutXLMTokenizer": XLMRobertaConverter,
|
||||
"LongformerTokenizer": RobertaConverter,
|
||||
"LEDTokenizer": RobertaConverter,
|
||||
|
@ -69,6 +69,7 @@ from . import (
|
||||
imagegpt,
|
||||
layoutlm,
|
||||
layoutlmv2,
|
||||
layoutlmv3,
|
||||
layoutxlm,
|
||||
led,
|
||||
longformer,
|
||||
|
@ -73,6 +73,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("imagegpt", "ImageGPTConfig"),
|
||||
("layoutlm", "LayoutLMConfig"),
|
||||
("layoutlmv2", "LayoutLMv2Config"),
|
||||
("layoutlmv3", "LayoutLMv3Config"),
|
||||
("led", "LEDConfig"),
|
||||
("longformer", "LongformerConfig"),
|
||||
("luke", "LukeConfig"),
|
||||
@ -183,6 +184,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
|
||||
("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("layoutlm", "LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("layoutlmv3", "LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("led", "LED_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("longformer", "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
|
||||
@ -294,6 +296,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("imagegpt", "ImageGPT"),
|
||||
("layoutlm", "LayoutLM"),
|
||||
("layoutlmv2", "LayoutLMv2"),
|
||||
("layoutlmv3", "LayoutLMv3"),
|
||||
("layoutxlm", "LayoutXLM"),
|
||||
("led", "LED"),
|
||||
("longformer", "Longformer"),
|
||||
|
@ -51,6 +51,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("glpn", "GLPNFeatureExtractor"),
|
||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
|
||||
("layoutlmv3", "LayoutLMv3FeatureExtractor"),
|
||||
("maskformer", "MaskFormerFeatureExtractor"),
|
||||
("perceiver", "PerceiverFeatureExtractor"),
|
||||
("poolformer", "PoolFormerFeatureExtractor"),
|
||||
|
@ -72,6 +72,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("imagegpt", "ImageGPTModel"),
|
||||
("layoutlm", "LayoutLMModel"),
|
||||
("layoutlmv2", "LayoutLMv2Model"),
|
||||
("layoutlmv3", "LayoutLMv3Model"),
|
||||
("led", "LEDModel"),
|
||||
("longformer", "LongformerModel"),
|
||||
("luke", "LukeModel"),
|
||||
@ -457,6 +458,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("ibert", "IBertForSequenceClassification"),
|
||||
("layoutlm", "LayoutLMForSequenceClassification"),
|
||||
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
|
||||
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
|
||||
("led", "LEDForSequenceClassification"),
|
||||
("longformer", "LongformerForSequenceClassification"),
|
||||
("mbart", "MBartForSequenceClassification"),
|
||||
@ -505,6 +507,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("gptj", "GPTJForQuestionAnswering"),
|
||||
("ibert", "IBertForQuestionAnswering"),
|
||||
("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
|
||||
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
|
||||
("led", "LEDForQuestionAnswering"),
|
||||
("longformer", "LongformerForQuestionAnswering"),
|
||||
("lxmert", "LxmertForQuestionAnswering"),
|
||||
@ -556,6 +559,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("ibert", "IBertForTokenClassification"),
|
||||
("layoutlm", "LayoutLMForTokenClassification"),
|
||||
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
|
||||
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
|
||||
("longformer", "LongformerForTokenClassification"),
|
||||
("megatron-bert", "MegatronBertForTokenClassification"),
|
||||
("mobilebert", "MobileBertForTokenClassification"),
|
||||
|
@ -40,6 +40,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("clip", "CLIPProcessor"),
|
||||
("flava", "FLAVAProcessor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
("layoutlmv3", "LayoutLMv3Processor"),
|
||||
("layoutxlm", "LayoutXLMProcessor"),
|
||||
("sew", "Wav2Vec2Processor"),
|
||||
("sew-d", "Wav2Vec2Processor"),
|
||||
|
@ -131,6 +131,7 @@ else:
|
||||
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
|
||||
|
@ -29,6 +29,7 @@ from ...utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_layoutlmv2": ["LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv2Config"],
|
||||
"processing_layoutlmv2": ["LayoutLMv2Processor"],
|
||||
"tokenization_layoutlmv2": ["LayoutLMv2Tokenizer"],
|
||||
}
|
||||
|
||||
@ -47,7 +48,6 @@ except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["feature_extraction_layoutlmv2"] = ["LayoutLMv2FeatureExtractor"]
|
||||
_import_structure["processing_layoutlmv2"] = ["LayoutLMv2Processor"]
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
@ -67,6 +67,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_layoutlmv2 import LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv2Config
|
||||
from .processing_layoutlmv2 import LayoutLMv2Processor
|
||||
from .tokenization_layoutlmv2 import LayoutLMv2Tokenizer
|
||||
|
||||
try:
|
||||
@ -84,7 +85,6 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
else:
|
||||
from .feature_extraction_layoutlmv2 import LayoutLMv2FeatureExtractor
|
||||
from .processing_layoutlmv2 import LayoutLMv2Processor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
@ -144,3 +144,17 @@ class LayoutLMv2Processor(ProcessorMixin):
|
||||
)
|
||||
|
||||
return images_with_overflow
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
|
||||
to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
@ -499,15 +499,17 @@ class LayoutLMv2Tokenizer(PreTrainedTokenizer):
|
||||
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
|
||||
|
||||
words = text if text_pair is None else text_pair
|
||||
assert boxes is not None, "You must provide corresponding bounding boxes"
|
||||
if boxes is None:
|
||||
raise ValueError("You must provide corresponding bounding boxes")
|
||||
if is_batched:
|
||||
assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
|
||||
if len(words) != len(boxes):
|
||||
raise ValueError("You must provide words and boxes for an equal amount of examples")
|
||||
for words_example, boxes_example in zip(words, boxes):
|
||||
assert len(words_example) == len(
|
||||
boxes_example
|
||||
), "You must provide as many words as there are bounding boxes"
|
||||
if len(words_example) != len(boxes_example):
|
||||
raise ValueError("You must provide as many words as there are bounding boxes")
|
||||
else:
|
||||
assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
|
||||
if len(words) != len(boxes):
|
||||
raise ValueError("You must provide as many words as there are bounding boxes")
|
||||
|
||||
if is_batched:
|
||||
if text_pair is not None and len(text) != len(text_pair):
|
||||
|
@ -260,15 +260,17 @@ class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast):
|
||||
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
|
||||
|
||||
words = text if text_pair is None else text_pair
|
||||
assert boxes is not None, "You must provide corresponding bounding boxes"
|
||||
if boxes is None:
|
||||
raise ValueError("You must provide corresponding bounding boxes")
|
||||
if is_batched:
|
||||
assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
|
||||
if len(words) != len(boxes):
|
||||
raise ValueError("You must provide words and boxes for an equal amount of examples")
|
||||
for words_example, boxes_example in zip(words, boxes):
|
||||
assert len(words_example) == len(
|
||||
boxes_example
|
||||
), "You must provide as many words as there are bounding boxes"
|
||||
if len(words_example) != len(boxes_example):
|
||||
raise ValueError("You must provide as many words as there are bounding boxes")
|
||||
else:
|
||||
assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
|
||||
if len(words) != len(boxes):
|
||||
raise ValueError("You must provide as many words as there are bounding boxes")
|
||||
|
||||
if is_batched:
|
||||
if text_pair is not None and len(text) != len(text_pair):
|
||||
|
107
src/transformers/models/layoutlmv3/__init__.py
Normal file
107
src/transformers/models/layoutlmv3/__init__.py
Normal file
@ -0,0 +1,107 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_layoutlmv3": ["LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv3Config"],
|
||||
"processing_layoutlmv3": ["LayoutLMv3Processor"],
|
||||
"tokenization_layoutlmv3": ["LayoutLMv3Tokenizer"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tokenization_layoutlmv3_fast"] = ["LayoutLMv3TokenizerFast"]
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_layoutlmv3"] = [
|
||||
"LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"LayoutLMv3ForQuestionAnswering",
|
||||
"LayoutLMv3ForSequenceClassification",
|
||||
"LayoutLMv3ForTokenClassification",
|
||||
"LayoutLMv3Model",
|
||||
"LayoutLMv3PreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["feature_extraction_layoutlmv3"] = ["LayoutLMv3FeatureExtractor"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_layoutlmv3 import LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv3Config
|
||||
from .processing_layoutlmv3 import LayoutLMv3Processor
|
||||
from .tokenization_layoutlmv3 import LayoutLMv3Tokenizer
|
||||
|
||||
try:
|
||||
if not is_tokenizers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tokenization_layoutlmv3_fast import LayoutLMv3TokenizerFast
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_layoutlmv3 import (
|
||||
LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
LayoutLMv3ForQuestionAnswering,
|
||||
LayoutLMv3ForSequenceClassification,
|
||||
LayoutLMv3ForTokenClassification,
|
||||
LayoutLMv3Model,
|
||||
LayoutLMv3PreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .feature_extraction_layoutlmv3 import LayoutLMv3FeatureExtractor
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
178
src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
Normal file
178
src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
Normal file
@ -0,0 +1,178 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" LayoutLMv3 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LAYOUTLMV3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class LayoutLMv3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`LayoutLMv3Model`]. It is used to instantiate an
|
||||
LayoutLMv3 model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the LayoutLMv3
|
||||
[microsoft/layoutlmv3-base](https://huggingface.co/microsoft/layoutlmv3-base) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 50265):
|
||||
Vocabulary size of the LayoutLMv3 model. Defines the number of different tokens that can be represented by
|
||||
the `inputs_ids` passed when calling [`LayoutLMv3Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (`int`, *optional*, defaults to 2):
|
||||
The vocabulary size of the `token_type_ids` passed when calling [`LayoutLMv3Model`].
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
max_2d_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum value that the 2D position embedding might ever be used with. Typically set this to something
|
||||
large just in case (e.g., 1024).
|
||||
coordinate_size (`int`, *optional*, defaults to `128`):
|
||||
Dimension of the coordinate embeddings.
|
||||
shape_size (`int`, *optional*, defaults to `128`):
|
||||
Dimension of the width and height embeddings.
|
||||
has_relative_attention_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use a relative attention bias in the self-attention mechanism.
|
||||
rel_pos_bins (`int`, *optional*, defaults to 32):
|
||||
The number of relative position bins to be used in the self-attention mechanism.
|
||||
max_rel_pos (`int`, *optional*, defaults to 128):
|
||||
The maximum number of relative positions to be used in the self-attention mechanism.
|
||||
max_rel_2d_pos (`int`, *optional*, defaults to 256):
|
||||
The maximum number of relative 2D positions in the self-attention mechanism.
|
||||
rel_2d_pos_bins (`int`, *optional*, defaults to 64):
|
||||
The number of 2D relative position bins in the self-attention mechanism.
|
||||
has_spatial_attention_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use a spatial attention bias in the self-attention mechanism.
|
||||
visual_embed (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to add patch embeddings.
|
||||
input_size (`int`, *optional*, defaults to `224`):
|
||||
The size (resolution) of the images.
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
The number of channels of the images.
|
||||
patch_size (`int`, *optional*, defaults to `16`)
|
||||
The size (resolution) of the patches.
|
||||
classifier_dropout (`float`, *optional*):
|
||||
The dropout ratio for the classification head.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import LayoutLMv3Model, LayoutLMv3Config
|
||||
|
||||
>>> # Initializing a LayoutLMv3 microsoft/layoutlmv3-base style configuration
|
||||
>>> configuration = LayoutLMv3Config()
|
||||
|
||||
>>> # Initializing a model from the microsoft/layoutlmv3-base style configuration
|
||||
>>> model = LayoutLMv3Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
model_type = "layoutlmv3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50265,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
max_2d_position_embeddings=1024,
|
||||
coordinate_size=128,
|
||||
shape_size=128,
|
||||
has_relative_attention_bias=True,
|
||||
rel_pos_bins=32,
|
||||
max_rel_pos=128,
|
||||
rel_2d_pos_bins=64,
|
||||
max_rel_2d_pos=256,
|
||||
has_spatial_attention_bias=True,
|
||||
text_embed=True,
|
||||
visual_embed=True,
|
||||
input_size=224,
|
||||
num_channels=3,
|
||||
patch_size=16,
|
||||
classifier_dropout=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=hidden_act,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
type_vocab_size=type_vocab_size,
|
||||
initializer_range=initializer_range,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.max_2d_position_embeddings = max_2d_position_embeddings
|
||||
self.coordinate_size = coordinate_size
|
||||
self.shape_size = shape_size
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.rel_pos_bins = rel_pos_bins
|
||||
self.max_rel_pos = max_rel_pos
|
||||
self.has_spatial_attention_bias = has_spatial_attention_bias
|
||||
self.rel_2d_pos_bins = rel_2d_pos_bins
|
||||
self.max_rel_2d_pos = max_rel_2d_pos
|
||||
self.text_embed = text_embed
|
||||
self.visual_embed = visual_embed
|
||||
self.input_size = input_size
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.classifier_dropout = classifier_dropout
|
@ -0,0 +1,242 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Feature extractor class for LayoutLMv3.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor
|
||||
from ...utils import TensorType, is_pytesseract_available, logging, requires_backends
|
||||
|
||||
|
||||
# soft dependency
|
||||
if is_pytesseract_available():
|
||||
import pytesseract
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ImageInput = Union[
|
||||
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
|
||||
]
|
||||
|
||||
|
||||
def normalize_box(box, width, height):
|
||||
return [
|
||||
int(1000 * (box[0] / width)),
|
||||
int(1000 * (box[1] / height)),
|
||||
int(1000 * (box[2] / width)),
|
||||
int(1000 * (box[3] / height)),
|
||||
]
|
||||
|
||||
|
||||
def apply_tesseract(image: Image.Image, lang: Optional[str]):
|
||||
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
|
||||
|
||||
# apply OCR
|
||||
data = pytesseract.image_to_data(image, lang=lang, output_type="dict")
|
||||
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
|
||||
|
||||
# filter empty words and corresponding coordinates
|
||||
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
|
||||
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
|
||||
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
|
||||
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
|
||||
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
|
||||
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
|
||||
|
||||
# turn coordinates into (left, top, left+width, top+height) format
|
||||
actual_boxes = []
|
||||
for x, y, w, h in zip(left, top, width, height):
|
||||
actual_box = [x, y, x + w, y + h]
|
||||
actual_boxes.append(actual_box)
|
||||
|
||||
image_width, image_height = image.size
|
||||
|
||||
# finally, normalize the bounding boxes
|
||||
normalized_boxes = []
|
||||
for box in actual_boxes:
|
||||
normalized_boxes.append(normalize_box(box, image_width, image_height))
|
||||
|
||||
assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
|
||||
|
||||
return words, normalized_boxes
|
||||
|
||||
|
||||
class LayoutLMv3FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
r"""
|
||||
Constructs a LayoutLMv3 feature extractor. This can be used to resize + normalize document images, as well as to
|
||||
apply OCR on them in order to get a list of words and normalized bounding boxes.
|
||||
|
||||
This feature extractor inherits from [`~feature_extraction_utils.PreTrainedFeatureExtractor`] which contains most
|
||||
of the main methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the input to a certain `size`.
|
||||
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
|
||||
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
|
||||
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize` is
|
||||
set to `True`.
|
||||
resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
|
||||
An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
|
||||
`PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
|
||||
if `do_resize` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to normalize the input with mean and standard deviation.
|
||||
image_mean (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
|
||||
The sequence of means for each channel, to be used when normalizing images.
|
||||
image_std (`List[int]`, defaults to `[0.5, 0.5, 0.5]`):
|
||||
The sequence of standard deviations for each channel, to be used when normalizing images.
|
||||
apply_ocr (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
|
||||
ocr_lang (`Optional[str]`, *optional*):
|
||||
The language, specified by its ISO code, to be used by the Tesseract OCR engine. By default, English is
|
||||
used.
|
||||
|
||||
<Tip>
|
||||
|
||||
LayoutLMv3FeatureExtractor uses Google's Tesseract OCR engine under the hood.
|
||||
|
||||
</Tip>"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize=True,
|
||||
size=224,
|
||||
resample=Image.BILINEAR,
|
||||
do_normalize=True,
|
||||
image_mean=None,
|
||||
image_std=None,
|
||||
apply_ocr=True,
|
||||
ocr_lang=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
self.apply_ocr = apply_ocr
|
||||
self.ocr_lang = ocr_lang
|
||||
|
||||
def __call__(
|
||||
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several image(s).
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
|
||||
width).
|
||||
- **words** -- Optional words as identified by Tesseract OCR (only when [`LayoutLMv3FeatureExtractor`] was
|
||||
initialized with `apply_ocr` set to `True`).
|
||||
- **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
|
||||
(only when [`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`).
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import LayoutLMv3FeatureExtractor
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
|
||||
|
||||
>>> # option 1: with apply_ocr=True (default)
|
||||
>>> feature_extractor = LayoutLMv3FeatureExtractor()
|
||||
>>> encoding = feature_extractor(image, return_tensors="pt")
|
||||
>>> print(encoding.keys())
|
||||
>>> # dict_keys(['pixel_values', 'words', 'boxes'])
|
||||
|
||||
>>> # option 2: with apply_ocr=False
|
||||
>>> feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
|
||||
>>> encoding = feature_extractor(image, return_tensors="pt")
|
||||
>>> print(encoding.keys())
|
||||
>>> # dict_keys(['pixel_values'])
|
||||
```"""
|
||||
|
||||
# Input type checking for clearer error
|
||||
valid_images = False
|
||||
|
||||
# Check that images has a valid type
|
||||
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
|
||||
valid_images = True
|
||||
elif isinstance(images, (list, tuple)):
|
||||
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
|
||||
valid_images = True
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError(
|
||||
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
|
||||
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
|
||||
f"but is of type {type(images)}."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(images, (list, tuple))
|
||||
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
|
||||
)
|
||||
|
||||
if not is_batched:
|
||||
images = [images]
|
||||
|
||||
# Tesseract OCR to get words + normalized bounding boxes
|
||||
if self.apply_ocr:
|
||||
requires_backends(self, "pytesseract")
|
||||
words_batch = []
|
||||
boxes_batch = []
|
||||
for image in images:
|
||||
words, boxes = apply_tesseract(self.to_pil_image(image), self.ocr_lang)
|
||||
words_batch.append(words)
|
||||
boxes_batch.append(boxes)
|
||||
|
||||
# transformations (resizing + normalization)
|
||||
if self.do_resize and self.size is not None:
|
||||
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
|
||||
if self.do_normalize:
|
||||
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
|
||||
|
||||
# return as BatchFeature
|
||||
data = {"pixel_values": images}
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
if self.apply_ocr:
|
||||
encoded_inputs["words"] = words_batch
|
||||
encoded_inputs["boxes"] = boxes_batch
|
||||
|
||||
return encoded_inputs
|
1309
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
Normal file
1309
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
Normal file
File diff suppressed because it is too large
Load Diff
158
src/transformers/models/layoutlmv3/processing_layoutlmv3.py
Normal file
158
src/transformers/models/layoutlmv3/processing_layoutlmv3.py
Normal file
@ -0,0 +1,158 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for LayoutLMv3.
|
||||
"""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
|
||||
|
||||
class LayoutLMv3Processor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a LayoutLMv3 processor which combines a LayoutLMv3 feature extractor and a LayoutLMv3 tokenizer into a
|
||||
single processor.
|
||||
|
||||
[`LayoutLMv3Processor`] offers all the functionalities you need to prepare data for the model.
|
||||
|
||||
It first uses [`LayoutLMv3FeatureExtractor`] to resize and normalize document images, and optionally applies OCR to
|
||||
get words and normalized bounding boxes. These are then provided to [`LayoutLMv3Tokenizer`] or
|
||||
[`LayoutLMv3TokenizerFast`], which turns the words and bounding boxes into token-level `input_ids`,
|
||||
`attention_mask`, `token_type_ids`, `bbox`. Optionally, one can provide integer `word_labels`, which are turned
|
||||
into token-level `labels` for token classification tasks (such as FUNSD, CORD).
|
||||
|
||||
Args:
|
||||
feature_extractor (`LayoutLMv3FeatureExtractor`):
|
||||
An instance of [`LayoutLMv3FeatureExtractor`]. The feature extractor is a required input.
|
||||
tokenizer (`LayoutLMv3Tokenizer` or `LayoutLMv3TokenizerFast`):
|
||||
An instance of [`LayoutLMv3Tokenizer`] or [`LayoutLMv3TokenizerFast`]. The tokenizer is a required input.
|
||||
"""
|
||||
feature_extractor_class = "LayoutLMv3FeatureExtractor"
|
||||
tokenizer_class = ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
|
||||
boxes: Union[List[List[int]], List[List[List[int]]]] = None,
|
||||
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
This method first forwards the `images` argument to [`~LayoutLMv3FeatureExtractor.__call__`]. In case
|
||||
[`LayoutLMv3FeatureExtractor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
|
||||
bounding boxes along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output,
|
||||
together with resized and normalized `pixel_values`. In case [`LayoutLMv3FeatureExtractor`] was initialized
|
||||
with `apply_ocr` set to `False`, it passes the words (`text`/``text_pair`) and `boxes` specified by the user
|
||||
along with the additional arguments to [`~LayoutLMv3Tokenizer.__call__`] and returns the output, together with
|
||||
resized and normalized `pixel_values`.
|
||||
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# verify input
|
||||
if self.feature_extractor.apply_ocr and (boxes is not None):
|
||||
raise ValueError(
|
||||
"You cannot provide bounding boxes "
|
||||
"if you initialized the feature extractor with apply_ocr set to True."
|
||||
)
|
||||
|
||||
if self.feature_extractor.apply_ocr and (word_labels is not None):
|
||||
raise ValueError(
|
||||
"You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
|
||||
)
|
||||
|
||||
# first, apply the feature extractor
|
||||
features = self.feature_extractor(images=images, return_tensors=return_tensors)
|
||||
|
||||
# second, apply the tokenizer
|
||||
if text is not None and self.feature_extractor.apply_ocr and text_pair is None:
|
||||
if isinstance(text, str):
|
||||
text = [text] # add batch dimension (as the feature extractor always adds a batch dimension)
|
||||
text_pair = features["words"]
|
||||
|
||||
encoded_inputs = self.tokenizer(
|
||||
text=text if text is not None else features["words"],
|
||||
text_pair=text_pair if text_pair is not None else None,
|
||||
boxes=boxes if boxes is not None else features["boxes"],
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# add pixel values
|
||||
images = features.pop("pixel_values")
|
||||
if return_overflowing_tokens is True:
|
||||
images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
|
||||
encoded_inputs["pixel_values"] = images
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def get_overflowing_images(self, images, overflow_to_sample_mapping):
|
||||
# in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
|
||||
images_with_overflow = []
|
||||
for sample_idx in overflow_to_sample_mapping:
|
||||
images_with_overflow.append(images[sample_idx])
|
||||
|
||||
if len(images_with_overflow) != len(overflow_to_sample_mapping):
|
||||
raise ValueError(
|
||||
"Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
|
||||
f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
|
||||
)
|
||||
|
||||
return images_with_overflow
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
|
||||
to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
1478
src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py
Normal file
1478
src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,853 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fast tokenization class for LayoutLMv3. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus
|
||||
and _encode_plus, in which the Rust tokenizer is used.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from tokenizers import pre_tokenizers, processors
|
||||
|
||||
from ...tokenization_utils_base import (
|
||||
BatchEncoding,
|
||||
EncodedInput,
|
||||
PaddingStrategy,
|
||||
PreTokenizedInput,
|
||||
TensorType,
|
||||
TextInput,
|
||||
TextInputPair,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import add_end_docstrings, logging
|
||||
from .tokenization_layoutlmv3 import (
|
||||
LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING,
|
||||
LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
|
||||
LayoutLMv3Tokenizer,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/vocab.json",
|
||||
"microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/vocab.json",
|
||||
},
|
||||
"merges_file": {
|
||||
"microsoft/layoutlmv3-base": "https://huggingface.co/microsoft/layoutlmv3-base/raw/main/merges.txt",
|
||||
"microsoft/layoutlmv3-large": "https://huggingface.co/microsoft/layoutlmv3-large/raw/main/merges.txt",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"microsoft/layoutlmv3-base": 512,
|
||||
"microsoft/layoutlmv3-large": 512,
|
||||
}
|
||||
|
||||
|
||||
class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast):
|
||||
r"""
|
||||
Construct a "fast" LayoutLMv3 tokenizer (backed by HuggingFace's *tokenizers* library). Based on BPE.
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
||||
refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`):
|
||||
Path to the merges file.
|
||||
errors (`str`, *optional*, defaults to `"replace"`):
|
||||
Paradigm to follow when decoding bytes to UTF-8. See
|
||||
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
||||
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||
|
||||
<Tip>
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the beginning of
|
||||
sequence. The token used is the `cls_token`.
|
||||
|
||||
</Tip>
|
||||
|
||||
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The end of sequence token.
|
||||
|
||||
<Tip>
|
||||
|
||||
When building a sequence using special tokens, this is not the token that is used for the end of sequence.
|
||||
The token used is the `sep_token`.
|
||||
|
||||
</Tip>
|
||||
|
||||
sep_token (`str`, *optional*, defaults to `"</s>"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
cls_token (`str`, *optional*, defaults to `"<s>"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
mask_token (`str`, *optional*, defaults to `"<mask>"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
||||
other word. (RoBERTa tokenizer detect beginning of words by the preceding space).
|
||||
trim_offsets (`bool`, *optional*, defaults to `True`):
|
||||
Whether the post processing step should trim offsets to avoid including whitespaces.
|
||||
cls_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
|
||||
The bounding box to use for the special [CLS] token.
|
||||
sep_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
|
||||
The bounding box to use for the special [SEP] token.
|
||||
pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`):
|
||||
The bounding box to use for the special [PAD] token.
|
||||
pad_token_label (`int`, *optional*, defaults to -100):
|
||||
The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's
|
||||
CrossEntropyLoss.
|
||||
only_label_first_subword (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to only label the first subword, in case word labels are provided.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
slow_tokenizer_class = LayoutLMv3Tokenizer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
errors="replace",
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
sep_token="</s>",
|
||||
cls_token="<s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
mask_token="<mask>",
|
||||
add_prefix_space=True,
|
||||
trim_offsets=True,
|
||||
cls_token_box=[0, 0, 0, 0],
|
||||
sep_token_box=[0, 0, 0, 0],
|
||||
pad_token_box=[0, 0, 0, 0],
|
||||
pad_token_label=-100,
|
||||
only_label_first_subword=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file,
|
||||
merges_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
errors=errors,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
sep_token=sep_token,
|
||||
cls_token=cls_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
mask_token=mask_token,
|
||||
add_prefix_space=add_prefix_space,
|
||||
trim_offsets=trim_offsets,
|
||||
cls_token_box=cls_token_box,
|
||||
sep_token_box=sep_token_box,
|
||||
pad_token_box=pad_token_box,
|
||||
pad_token_label=pad_token_label,
|
||||
only_label_first_subword=only_label_first_subword,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
||||
|
||||
self.add_prefix_space = add_prefix_space
|
||||
|
||||
tokenizer_component = "post_processor"
|
||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||
if tokenizer_component_instance:
|
||||
state = json.loads(tokenizer_component_instance.__getstate__())
|
||||
|
||||
# The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class`
|
||||
if "sep" in state:
|
||||
state["sep"] = tuple(state["sep"])
|
||||
if "cls" in state:
|
||||
state["cls"] = tuple(state["cls"])
|
||||
|
||||
changes_to_apply = False
|
||||
|
||||
if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
||||
state["add_prefix_space"] = add_prefix_space
|
||||
changes_to_apply = True
|
||||
|
||||
if state.get("trim_offsets", trim_offsets) != trim_offsets:
|
||||
state["trim_offsets"] = trim_offsets
|
||||
changes_to_apply = True
|
||||
|
||||
if changes_to_apply:
|
||||
component_class = getattr(processors, state.pop("type"))
|
||||
new_value = component_class(**state)
|
||||
setattr(self.backend_tokenizer, tokenizer_component, new_value)
|
||||
|
||||
# additional properties
|
||||
self.cls_token_box = cls_token_box
|
||||
self.sep_token_box = sep_token_box
|
||||
self.pad_token_box = pad_token_box
|
||||
self.pad_token_label = pad_token_label
|
||||
self.only_label_first_subword = only_label_first_subword
|
||||
|
||||
@add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.__call__
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
|
||||
text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
|
||||
boxes: Union[List[List[int]], List[List[List[int]]]] = None,
|
||||
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
|
||||
sequences with word-level normalized bounding boxes and optional labels.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
|
||||
(words of a single example or questions of a batch of examples) or a list of list of strings (batch of
|
||||
words).
|
||||
text_pair (`List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
|
||||
(pretokenized string).
|
||||
boxes (`List[List[int]]`, `List[List[List[int]]]`):
|
||||
Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
|
||||
word_labels (`List[int]`, `List[List[int]]`, *optional*):
|
||||
Word-level integer labels (for token classification tasks such as FUNSD, CORD).
|
||||
"""
|
||||
# Input type checking for clearer error
|
||||
def _is_valid_text_input(t):
|
||||
if isinstance(t, str):
|
||||
# Strings are fine
|
||||
return True
|
||||
elif isinstance(t, (list, tuple)):
|
||||
# List are fine as long as they are...
|
||||
if len(t) == 0:
|
||||
# ... empty
|
||||
return True
|
||||
elif isinstance(t[0], str):
|
||||
# ... list of strings
|
||||
return True
|
||||
elif isinstance(t[0], (list, tuple)):
|
||||
# ... list with an empty list or with a list of strings
|
||||
return len(t[0]) == 0 or isinstance(t[0][0], str)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
if text_pair is not None:
|
||||
# in case text + text_pair are provided, text = questions, text_pair = words
|
||||
if not _is_valid_text_input(text):
|
||||
raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ")
|
||||
if not isinstance(text_pair, (list, tuple)):
|
||||
raise ValueError(
|
||||
"Words must be of type `List[str]` (single pretokenized example), "
|
||||
"or `List[List[str]]` (batch of pretokenized examples)."
|
||||
)
|
||||
else:
|
||||
# in case only text is provided => must be words
|
||||
if not isinstance(text, (list, tuple)):
|
||||
raise ValueError(
|
||||
"Words must be of type `List[str]` (single pretokenized example), "
|
||||
"or `List[List[str]]` (batch of pretokenized examples)."
|
||||
)
|
||||
|
||||
if text_pair is not None:
|
||||
is_batched = isinstance(text, (list, tuple))
|
||||
else:
|
||||
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
|
||||
|
||||
words = text if text_pair is None else text_pair
|
||||
if boxes is None:
|
||||
raise ValueError("You must provide corresponding bounding boxes")
|
||||
if is_batched:
|
||||
if len(words) != len(boxes):
|
||||
raise ValueError("You must provide words and boxes for an equal amount of examples")
|
||||
for words_example, boxes_example in zip(words, boxes):
|
||||
if len(words_example) != len(boxes_example):
|
||||
raise ValueError("You must provide as many words as there are bounding boxes")
|
||||
else:
|
||||
if len(words) != len(boxes):
|
||||
raise ValueError("You must provide as many words as there are bounding boxes")
|
||||
|
||||
if is_batched:
|
||||
if text_pair is not None and len(text) != len(text_pair):
|
||||
raise ValueError(
|
||||
f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
|
||||
f" {len(text_pair)}."
|
||||
)
|
||||
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
|
||||
is_pair = bool(text_pair is not None)
|
||||
return self.batch_encode_plus(
|
||||
batch_text_or_text_pairs=batch_text_or_text_pairs,
|
||||
is_pair=is_pair,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.encode_plus(
|
||||
text=text,
|
||||
text_pair=text_pair,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.batch_encode_plus
|
||||
def batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[
|
||||
List[TextInput],
|
||||
List[TextInputPair],
|
||||
List[PreTokenizedInput],
|
||||
],
|
||||
is_pair: bool = None,
|
||||
boxes: Optional[List[List[List[int]]]] = None,
|
||||
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._batch_encode_plus(
|
||||
batch_text_or_text_pairs=batch_text_or_text_pairs,
|
||||
is_pair=is_pair,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.tokenize
|
||||
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
|
||||
batched_input = [(text, pair)] if pair else [text]
|
||||
encodings = self._tokenizer.encode_batch(
|
||||
batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
|
||||
)
|
||||
|
||||
return encodings[0].tokens
|
||||
|
||||
@add_end_docstrings(LAYOUTLMV3_ENCODE_KWARGS_DOCSTRING, LAYOUTLMV3_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.encode_plus
|
||||
def encode_plus(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput],
|
||||
text_pair: Optional[PreTokenizedInput] = None,
|
||||
boxes: Optional[List[List[int]]] = None,
|
||||
word_labels: Optional[List[int]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Tokenize and prepare for the model a sequence or a pair of sequences. .. warning:: This method is deprecated,
|
||||
`__call__` should be used instead.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
|
||||
text_pair (`List[str]` or `List[int]`, *optional*):
|
||||
Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
|
||||
list of list of strings (words of a batch of examples).
|
||||
"""
|
||||
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._encode_plus(
|
||||
text=text,
|
||||
boxes=boxes,
|
||||
text_pair=text_pair,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._batch_encode_plus with LayoutLMv2->LayoutLMv3
|
||||
def _batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[
|
||||
List[TextInput],
|
||||
List[TextInputPair],
|
||||
List[PreTokenizedInput],
|
||||
],
|
||||
is_pair: bool = None,
|
||||
boxes: Optional[List[List[List[int]]]] = None,
|
||||
word_labels: Optional[List[List[int]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
) -> BatchEncoding:
|
||||
|
||||
if not isinstance(batch_text_or_text_pairs, list):
|
||||
raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
|
||||
|
||||
# Set the truncation and padding strategy and restore the initial configuration
|
||||
self.set_truncation_and_padding(
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
)
|
||||
|
||||
if is_pair:
|
||||
batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
|
||||
|
||||
encodings = self._tokenizer.encode_batch(
|
||||
batch_text_or_text_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
is_pretokenized=True, # we set this to True as LayoutLMv3 always expects pretokenized inputs
|
||||
)
|
||||
|
||||
# Convert encoding to dict
|
||||
# `Tokens` has type: Tuple[
|
||||
# List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
|
||||
# List[EncodingFast]
|
||||
# ]
|
||||
# with nested dimensions corresponding to batch, overflows, sequence length
|
||||
tokens_and_encodings = [
|
||||
self._convert_encoding(
|
||||
encoding=encoding,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=True
|
||||
if word_labels is not None
|
||||
else return_offsets_mapping, # we use offsets to create the labels
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
)
|
||||
for encoding in encodings
|
||||
]
|
||||
|
||||
# Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
|
||||
# From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
|
||||
# (we say ~ because the number of overflow varies with the example in the batch)
|
||||
#
|
||||
# To match each overflowing sample with the original sample in the batch
|
||||
# we add an overflow_to_sample_mapping array (see below)
|
||||
sanitized_tokens = {}
|
||||
for key in tokens_and_encodings[0][0].keys():
|
||||
stack = [e for item, _ in tokens_and_encodings for e in item[key]]
|
||||
sanitized_tokens[key] = stack
|
||||
sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
|
||||
|
||||
# If returning overflowing tokens, we need to return a mapping
|
||||
# from the batch idx to the original sample
|
||||
if return_overflowing_tokens:
|
||||
overflow_to_sample_mapping = []
|
||||
for i, (toks, _) in enumerate(tokens_and_encodings):
|
||||
overflow_to_sample_mapping += [i] * len(toks["input_ids"])
|
||||
sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
|
||||
|
||||
for input_ids in sanitized_tokens["input_ids"]:
|
||||
self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
|
||||
|
||||
# create the token boxes
|
||||
token_boxes = []
|
||||
for batch_index in range(len(sanitized_tokens["input_ids"])):
|
||||
if return_overflowing_tokens:
|
||||
original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
|
||||
else:
|
||||
original_index = batch_index
|
||||
token_boxes_example = []
|
||||
for id, sequence_id, word_id in zip(
|
||||
sanitized_tokens["input_ids"][batch_index],
|
||||
sanitized_encodings[batch_index].sequence_ids,
|
||||
sanitized_encodings[batch_index].word_ids,
|
||||
):
|
||||
if word_id is not None:
|
||||
if is_pair and sequence_id == 0:
|
||||
token_boxes_example.append(self.pad_token_box)
|
||||
else:
|
||||
token_boxes_example.append(boxes[original_index][word_id])
|
||||
else:
|
||||
if id == self.cls_token_id:
|
||||
token_boxes_example.append(self.cls_token_box)
|
||||
elif id == self.sep_token_id:
|
||||
token_boxes_example.append(self.sep_token_box)
|
||||
elif id == self.pad_token_id:
|
||||
token_boxes_example.append(self.pad_token_box)
|
||||
else:
|
||||
raise ValueError("Id not recognized")
|
||||
token_boxes.append(token_boxes_example)
|
||||
|
||||
sanitized_tokens["bbox"] = token_boxes
|
||||
|
||||
# optionally, create the labels
|
||||
if word_labels is not None:
|
||||
labels = []
|
||||
for batch_index in range(len(sanitized_tokens["input_ids"])):
|
||||
if return_overflowing_tokens:
|
||||
original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
|
||||
else:
|
||||
original_index = batch_index
|
||||
labels_example = []
|
||||
for id, offset, word_id in zip(
|
||||
sanitized_tokens["input_ids"][batch_index],
|
||||
sanitized_tokens["offset_mapping"][batch_index],
|
||||
sanitized_encodings[batch_index].word_ids,
|
||||
):
|
||||
if word_id is not None:
|
||||
if self.only_label_first_subword:
|
||||
if offset[0] == 0:
|
||||
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
|
||||
labels_example.append(word_labels[original_index][word_id])
|
||||
else:
|
||||
labels_example.append(self.pad_token_label)
|
||||
else:
|
||||
labels_example.append(word_labels[original_index][word_id])
|
||||
else:
|
||||
labels_example.append(self.pad_token_label)
|
||||
labels.append(labels_example)
|
||||
|
||||
sanitized_tokens["labels"] = labels
|
||||
# finally, remove offsets if the user didn't want them
|
||||
if not return_offsets_mapping:
|
||||
del sanitized_tokens["offset_mapping"]
|
||||
|
||||
return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
|
||||
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._encode_plus
|
||||
def _encode_plus(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput],
|
||||
text_pair: Optional[PreTokenizedInput] = None,
|
||||
boxes: Optional[List[List[int]]] = None,
|
||||
word_labels: Optional[List[int]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
|
||||
# make it a batched input
|
||||
# 2 options:
|
||||
# 1) only text, in which case text must be a list of str
# 2) text + text_pair, in which case text is a str and text_pair a list of str
|
||||
batched_input = [(text, text_pair)] if text_pair else [text]
|
||||
batched_boxes = [boxes]
|
||||
batched_word_labels = [word_labels] if word_labels is not None else None
|
||||
batched_output = self._batch_encode_plus(
|
||||
batched_input,
|
||||
is_pair=bool(text_pair is not None),
|
||||
boxes=batched_boxes,
|
||||
word_labels=batched_word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If return_tensors is None, we can remove the leading batch axis
|
||||
# Overflowing tokens are returned as a batch of output so we keep them in this case
|
||||
if return_tensors is None and not return_overflowing_tokens:
|
||||
batched_output = BatchEncoding(
|
||||
{
|
||||
key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
|
||||
for key, value in batched_output.items()
|
||||
},
|
||||
batched_output.encodings,
|
||||
)
|
||||
|
||||
self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
|
||||
|
||||
return batched_output
|
||||
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast._pad
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
>= 7.5 (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if return_attention_mask and "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
if self.padding_side == "right":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
||||
)
|
||||
if "bbox" in encoded_inputs:
|
||||
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
|
||||
if "labels" in encoded_inputs:
|
||||
encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
||||
elif self.padding_side == "left":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
]
|
||||
if "bbox" in encoded_inputs:
|
||||
encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
|
||||
if "labels" in encoded_inputs:
|
||||
encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"]
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
else:
|
||||
raise ValueError("Invalid padding strategy: " + str(self.padding_side))
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
# Copied from transformers.models.layoutlmv2.tokenization_layoutlmv2_fast.LayoutLMv2TokenizerFast.save_vocabulary
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||
return tuple(files)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
|
||||
if token_ids_1 is None:
|
||||
return output
|
||||
|
||||
return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not
make use of token type ids, therefore a list of zeros is returned.

Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
Returns:
|
||||
`List[int]`: List of zeros.
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
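
As a rough usage sketch of the fast tokenizer defined above (illustrative only; the checkpoint name and the example words, boxes and labels are assumptions, not part of this commit): words are passed pre-tokenized, each with one bounding box on the 0-1000 scale, and `word_labels` are assigned to the first subword of every word while remaining subwords receive `pad_token_label`.

# Illustrative sketch, assuming the "microsoft/layoutlmv3-base" checkpoint is available.
from transformers import LayoutLMv3TokenizerFast

tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")

words = ["hello", "world"]              # pre-tokenized input: one entry per word
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]    # one (x0, y0, x1, y1) box per word
word_labels = [0, 1]                    # one label per word (token classification)

encoding = tokenizer(words, boxes=boxes, word_labels=word_labels)
# "bbox" is aligned with "input_ids": special tokens get the cls/sep/pad boxes,
# and only the first subword of each word keeps its real label.
print(list(encoding.keys()))  # e.g. input_ids, attention_mask, bbox, labels (exact set may vary)
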
|
@ -28,7 +28,9 @@ from ...utils import (
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {}
|
||||
_import_structure = {
|
||||
"processing_layoutxlm": ["LayoutXLMProcessor"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_sentencepiece_available():
|
||||
@ -46,15 +48,9 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["tokenization_layoutxlm_fast"] = ["LayoutXLMTokenizerFast"]
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["processing_layoutxlm"] = ["LayoutXLMProcessor"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .processing_layoutxlm import LayoutXLMProcessor
|
||||
|
||||
try:
|
||||
if not is_sentencepiece_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@ -71,14 +67,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .tokenization_layoutxlm_fast import LayoutXLMTokenizerFast
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .processing_layoutxlm import LayoutXLMProcessor
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -121,6 +121,37 @@ class LayoutXLMProcessor(ProcessorMixin):
|
||||
)
|
||||
|
||||
# add pixel values
|
||||
encoded_inputs["image"] = features.pop("pixel_values")
|
||||
images = features.pop("pixel_values")
|
||||
if return_overflowing_tokens is True:
|
||||
images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
|
||||
encoded_inputs["image"] = images
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def get_overflowing_images(self, images, overflow_to_sample_mapping):
|
||||
# in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
|
||||
images_with_overflow = []
|
||||
for sample_idx in overflow_to_sample_mapping:
|
||||
images_with_overflow.append(images[sample_idx])
|
||||
|
||||
if len(images_with_overflow) != len(overflow_to_sample_mapping):
|
||||
raise ValueError(
|
||||
"Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
|
||||
f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
|
||||
)
|
||||
|
||||
return images_with_overflow
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
|
||||
to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
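
The overflow handling added to the processor above can be illustrated with a small schematic (plain Python, not part of the commit): when `return_overflowing_tokens=True`, every overflowed chunk reuses the image of the sample it came from, so the `image` field stays aligned with `input_ids`.

# Schematic sketch of get_overflowing_images with toy values.
images = ["img_A", "img_B"]             # one image per original sample
overflow_to_sample_mapping = [0, 0, 1]  # sample 0 overflowed into two chunks

images_with_overflow = [images[sample_idx] for sample_idx in overflow_to_sample_mapping]
assert images_with_overflow == ["img_A", "img_A", "img_B"]
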
@ -2367,6 +2367,44 @@ class LayoutLMv2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class LayoutLMv3ForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv3ForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv3ForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv3Model(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv3PreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -171,6 +171,13 @@ class LayoutLMv2TokenizerFast(metaclass=DummyObject):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
||||
class LayoutLMv3TokenizerFast(metaclass=DummyObject):
|
||||
_backends = ["tokenizers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
||||
class LayoutXLMTokenizerFast(metaclass=DummyObject):
|
||||
_backends = ["tokenizers"]
|
||||
|
||||
|
@ -94,14 +94,7 @@ class LayoutLMv2FeatureExtractor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LayoutLMv2Processor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LayoutXLMProcessor(metaclass=DummyObject):
|
||||
class LayoutLMv3FeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -215,10 +215,11 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
# verify input_ids
|
||||
# this was obtained with Tesseract 4.1.1
|
||||
# fmt: off
|
||||
expected_decoding = "[CLS] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
|
||||
# fmt: on
|
||||
decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
|
||||
decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# batched
|
||||
@ -236,10 +237,11 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
# verify input_ids
|
||||
# this was obtained with Tesseract 4.1.1
|
||||
# fmt: off
|
||||
expected_decoding = "[CLS] 7 itc limited report and accounts 2013 itc ’ s brands : an asset for the nation the consumer needs and aspirations they fulfil, the benefit they generate for millions across itc ’ s value chains, the future - ready capabilities that support them, and the value that they create for the country, have made itc ’ s brands national assets, adding to india ’ s competitiveness. it is itc ’ s aspiration to be the no 1 fmcg player in the country, driven by its new fmcg businesses. a recent nielsen report has highlighted that itc's new fmcg businesses are the fastest growing among the top consumer goods companies operating in india. itc takes justifiable pride that, along with generating economic value, these celebrated indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. di wills * ; love delightfully soft skin? aia ans source : https : / / www. industrydocuments. ucsf. edu / docs / snbx0223 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]" # noqa: E231
|
||||
# fmt: on
|
||||
decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
|
||||
decoding = processor.decode(input_processor.input_ids[1].tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
@slow
|
||||
@ -266,7 +268,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify input_ids
|
||||
expected_decoding = "[CLS] hello world [SEP]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
|
||||
decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# batched
|
||||
@ -281,7 +283,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify input_ids
|
||||
expected_decoding = "[CLS] hello world [SEP] [PAD] [PAD] [PAD]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
|
||||
decoding = processor.decode(input_processor.input_ids[0].tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# verify bbox
|
||||
@ -320,7 +322,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify input_ids
|
||||
expected_decoding = "[CLS] weirdly world [SEP]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
|
||||
decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# verify labels
|
||||
@ -342,7 +344,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify input_ids
|
||||
expected_decoding = "[CLS] my name is niels [SEP]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
|
||||
decoding = processor.decode(input_processor.input_ids[1].tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# verify bbox
|
||||
@ -382,10 +384,11 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(actual_keys, expected_keys)
|
||||
|
||||
# verify input_ids
|
||||
# this was obtained with Tesseract 4.1.1
|
||||
# fmt: off
|
||||
expected_decoding = "[CLS] what's his name? [SEP] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
|
||||
# fmt: on
|
||||
decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
|
||||
decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# batched
|
||||
@ -400,8 +403,9 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(actual_keys, expected_keys)
|
||||
|
||||
# verify input_ids
|
||||
# this was obtained with Tesseract 4.1.1
|
||||
expected_decoding = "[CLS] what's the time [SEP] 7 itc limited report and accounts 2013 itc ’ s [SEP]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
|
||||
decoding = processor.decode(input_processor.input_ids[1].tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# verify bbox
|
||||
@ -434,7 +438,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify input_ids
|
||||
expected_decoding = "[CLS] what's his name? [SEP] hello world [SEP]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
|
||||
decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# batched
|
||||
@ -450,11 +454,11 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
|
||||
|
||||
# verify input_ids
|
||||
expected_decoding = "[CLS] how old is he? [SEP] hello world [SEP] [PAD] [PAD] [PAD]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
|
||||
decoding = processor.decode(input_processor.input_ids[0].tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
expected_decoding = "[CLS] what's the time [SEP] my name is niels [SEP]"
|
||||
decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
|
||||
decoding = processor.decode(input_processor.input_ids[1].tolist())
|
||||
self.assertSequenceEqual(decoding, expected_decoding)
|
||||
|
||||
# verify bbox
|
||||
|
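
The processor-level assertions above correspond to a usage pattern roughly like the following sketch (the checkpoint name, placeholder image path and example inputs are assumptions): for visual question answering without OCR, the question goes in as `text` and the document words as `text_pair`, each word paired with its bounding box.

# Illustrative sketch, assuming the "microsoft/layoutlmv2-base-uncased" checkpoint is available.
from PIL import Image
from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2Processor, LayoutLMv2TokenizerFast

feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
processor = LayoutLMv2Processor(feature_extractor, tokenizer)

image = Image.open("document.png").convert("RGB")  # placeholder path
question = "what's his name?"
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]

encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
print(processor.decode(encoding.input_ids.squeeze().tolist()))
# -> "[CLS] what's his name? [SEP] hello world [SEP]"
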
0
tests/models/layoutlmv3/__init__.py
Normal file
213
tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py
Normal file
@ -0,0 +1,213 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_pytesseract, require_torch
|
||||
from transformers.utils import is_pytesseract_available, is_torch_available
|
||||
|
||||
from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_pytesseract_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import LayoutLMv3FeatureExtractor
|
||||
|
||||
|
||||
class LayoutLMv3FeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=18,
|
||||
apply_ocr=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.apply_ocr = apply_ocr
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {"do_resize": self.do_resize, "size": self.size, "apply_ocr": self.apply_ocr}
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_pytesseract
|
||||
class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = LayoutLMv3FeatureExtractor if is_pytesseract_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = LayoutLMv3FeatureExtractionTester(self)
|
||||
|
||||
@property
|
||||
def feat_extract_dict(self):
|
||||
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
def test_feat_extract_properties(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "apply_ocr"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PIL images
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoding = feature_extractor(image_inputs[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding.pixel_values.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
self.assertIsInstance(encoding.words, list)
|
||||
self.assertIsInstance(encoding.boxes, list)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_LayoutLMv3_integration_test(self):
|
||||
# with apply_ocr = True
|
||||
feature_extractor = LayoutLMv3FeatureExtractor()
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")
|
||||
|
||||
image = Image.open(ds[0]["file"]).convert("RGB")
|
||||
|
||||
encoding = feature_extractor(image, return_tensors="pt")
|
||||
|
||||
self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
|
||||
self.assertEqual(len(encoding.words), len(encoding.boxes))
|
||||
|
||||
# fmt: off
|
||||
# the words and boxes were obtained with Tesseract 4.1.1
|
||||
expected_words = [['11:14', 'to', '11:39', 'a.m', '11:39', 'to', '11:44', 'a.m.', '11:44', 'a.m.', 'to', '12:25', 'p.m.', '12:25', 'to', '12:58', 'p.m.', '12:58', 'to', '4:00', 'p.m.', '2:00', 'to', '5:00', 'p.m.', 'Coffee', 'Break', 'Coffee', 'will', 'be', 'served', 'for', 'men', 'and', 'women', 'in', 'the', 'lobby', 'adjacent', 'to', 'exhibit', 'area.', 'Please', 'move', 'into', 'exhibit', 'area.', '(Exhibits', 'Open)', 'TRRF', 'GENERAL', 'SESSION', '(PART', '|)', 'Presiding:', 'Lee', 'A.', 'Waller', 'TRRF', 'Vice', 'President', '“Introductory', 'Remarks”', 'Lee', 'A.', 'Waller,', 'TRRF', 'Vice', 'Presi-', 'dent', 'Individual', 'Interviews', 'with', 'TRRF', 'Public', 'Board', 'Members', 'and', 'Sci-', 'entific', 'Advisory', 'Council', 'Mem-', 'bers', 'Conducted', 'by', 'TRRF', 'Treasurer', 'Philip', 'G.', 'Kuehn', 'to', 'get', 'answers', 'which', 'the', 'public', 'refrigerated', 'warehousing', 'industry', 'is', 'looking', 'for.', 'Plus', 'questions', 'from', 'the', 'floor.', 'Dr.', 'Emil', 'M.', 'Mrak,', 'University', 'of', 'Cal-', 'ifornia,', 'Chairman,', 'TRRF', 'Board;', 'Sam', 'R.', 'Cecil,', 'University', 'of', 'Georgia', 'College', 'of', 'Agriculture;', 'Dr.', 'Stanley', 'Charm,', 'Tufts', 'University', 'School', 'of', 'Medicine;', 'Dr.', 'Robert', 'H.', 'Cotton,', 'ITT', 'Continental', 'Baking', 'Company;', 'Dr.', 'Owen', 'Fennema,', 'University', 'of', 'Wis-', 'consin;', 'Dr.', 'Robert', 'E.', 'Hardenburg,', 'USDA.', 'Questions', 'and', 'Answers', 'Exhibits', 'Open', 'Capt.', 'Jack', 'Stoney', 'Room', 'TRRF', 'Scientific', 'Advisory', 'Council', 'Meeting', 'Ballroom', 'Foyer']] # noqa: E231
|
||||
expected_boxes = [[[141, 57, 214, 69], [228, 58, 252, 69], [141, 75, 216, 88], [230, 79, 280, 88], [142, 260, 218, 273], [230, 261, 255, 273], [143, 279, 218, 290], [231, 282, 290, 291], [143, 342, 218, 354], [231, 345, 289, 355], [202, 362, 227, 373], [143, 379, 220, 392], [231, 382, 291, 394], [144, 714, 220, 726], [231, 715, 256, 726], [144, 732, 220, 745], [232, 736, 291, 747], [144, 769, 218, 782], [231, 770, 256, 782], [141, 788, 202, 801], [215, 791, 274, 804], [143, 826, 204, 838], [215, 826, 240, 838], [142, 844, 202, 857], [215, 847, 274, 859], [334, 57, 427, 69], [440, 57, 522, 69], [369, 75, 461, 88], [469, 75, 516, 88], [528, 76, 562, 88], [570, 76, 667, 88], [675, 75, 711, 87], [721, 79, 778, 88], [789, 75, 840, 88], [369, 97, 470, 107], [484, 94, 507, 106], [518, 94, 562, 107], [576, 94, 655, 110], [668, 94, 792, 109], [804, 95, 829, 107], [369, 113, 465, 125], [477, 116, 547, 125], [562, 113, 658, 125], [671, 116, 748, 125], [761, 113, 811, 125], [369, 131, 465, 143], [477, 133, 548, 143], [563, 130, 698, 145], [710, 130, 802, 146], [336, 171, 412, 183], [423, 171, 572, 183], [582, 170, 716, 184], [728, 171, 817, 187], [829, 171, 844, 186], [338, 197, 482, 212], [507, 196, 557, 209], [569, 196, 595, 208], [610, 196, 702, 209], [505, 214, 583, 226], [595, 214, 656, 227], [670, 215, 807, 227], [335, 259, 543, 274], [556, 259, 708, 272], [372, 279, 422, 291], [435, 279, 460, 291], [474, 279, 574, 292], [587, 278, 664, 291], [676, 278, 738, 291], [751, 279, 834, 291], [372, 298, 434, 310], [335, 341, 483, 354], [497, 341, 655, 354], [667, 341, 728, 354], [740, 341, 825, 354], [335, 360, 430, 372], [442, 360, 534, 372], [545, 359, 687, 372], [697, 360, 754, 372], [765, 360, 823, 373], [334, 378, 428, 391], [440, 378, 577, 394], [590, 378, 705, 391], [720, 378, 801, 391], [334, 397, 400, 409], [370, 416, 529, 429], [544, 416, 576, 432], [587, 416, 665, 428], [677, 416, 814, 429], [372, 435, 452, 450], [465, 434, 495, 447], [511, 434, 600, 447], [611, 436, 637, 447], [649, 436, 694, 451], [705, 438, 824, 447], [369, 453, 452, 466], [464, 454, 509, 466], [522, 453, 611, 469], [625, 453, 792, 469], [370, 472, 556, 488], [570, 472, 684, 487], [697, 472, 718, 485], [732, 472, 835, 488], [369, 490, 411, 503], [425, 490, 484, 503], [496, 490, 635, 506], [645, 490, 707, 503], [718, 491, 761, 503], [771, 490, 840, 503], [336, 510, 374, 521], [388, 510, 447, 522], [460, 510, 489, 521], [503, 510, 580, 522], [592, 509, 736, 525], [745, 509, 770, 522], [781, 509, 840, 522], [338, 528, 434, 541], [448, 528, 596, 541], [609, 527, 687, 540], [700, 528, 792, 541], [336, 546, 397, 559], [407, 546, 431, 559], [443, 546, 525, 560], [537, 546, 680, 562], [688, 546, 714, 559], [722, 546, 837, 562], [336, 565, 449, 581], [461, 565, 485, 577], [497, 565, 665, 581], [681, 565, 718, 577], [732, 565, 837, 580], [337, 584, 438, 597], [452, 583, 521, 596], [535, 584, 677, 599], [690, 583, 787, 596], [801, 583, 825, 596], [338, 602, 478, 615], [492, 602, 530, 614], [543, 602, 638, 615], [650, 602, 676, 614], [688, 602, 788, 615], [802, 602, 843, 614], [337, 621, 502, 633], [516, 621, 615, 637], [629, 621, 774, 636], [789, 621, 827, 633], [337, 639, 418, 652], [432, 640, 571, 653], [587, 639, 731, 655], [743, 639, 769, 652], [780, 639, 841, 652], [338, 658, 440, 673], [455, 658, 491, 670], [508, 658, 602, 671], [616, 658, 638, 670], [654, 658, 835, 674], [337, 677, 429, 689], [337, 714, 482, 726], [495, 714, 548, 726], [561, 714, 683, 726], [338, 770, 461, 782], [474, 769, 554, 785], [489, 788, 562, 803], 
[576, 788, 643, 801], [656, 787, 751, 804], [764, 788, 844, 801], [334, 825, 421, 838], [430, 824, 574, 838], [584, 824, 723, 841], [335, 844, 450, 857], [464, 843, 583, 860], [628, 862, 755, 875], [769, 861, 848, 878]]] # noqa: E231
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(encoding.words, expected_words)
|
||||
self.assertListEqual(encoding.boxes, expected_boxes)
|
||||
|
||||
# with apply_ocr = False
|
||||
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
|
||||
|
||||
encoding = feature_extractor(image, return_tensors="pt")
|
||||
|
||||
self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
|
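
For reference, the two feature-extractor modes exercised above amount to roughly the following sketch ("document.png" is a placeholder path, not a repository fixture): with `apply_ocr=True` the extractor runs Tesseract and returns `words` and `boxes` alongside the resized 224x224 `pixel_values`; with `apply_ocr=False` it only resizes and normalizes the image, leaving words and boxes to be supplied by the caller.

# Illustrative sketch of LayoutLMv3FeatureExtractor in both modes.
from PIL import Image
from transformers import LayoutLMv3FeatureExtractor

image = Image.open("document.png").convert("RGB")  # placeholder path

# OCR mode (default): pixel_values of shape (1, 3, 224, 224) plus Tesseract words and boxes.
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=True)
encoding = feature_extractor(image, return_tensors="pt")
words, boxes = encoding.words, encoding.boxes

# No-OCR mode: only pixel_values; words and boxes must come from elsewhere.
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
encoding = feature_extractor(image, return_tensors="pt")
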
399
tests/models/layoutlmv3/test_modeling_layoutlmv3.py
Normal file
@ -0,0 +1,399 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch LayoutLMv3 model. """
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
LayoutLMv3Config,
|
||||
LayoutLMv3ForQuestionAnswering,
|
||||
LayoutLMv3ForSequenceClassification,
|
||||
LayoutLMv3ForTokenClassification,
|
||||
LayoutLMv3Model,
|
||||
)
|
||||
from transformers.models.layoutlmv3.modeling_layoutlmv3 import LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import LayoutLMv3FeatureExtractor
|
||||
|
||||
|
||||
class LayoutLMv3ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
num_channels=3,
|
||||
image_size=4,
|
||||
patch_size=2,
|
||||
text_seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=36,
|
||||
num_hidden_layers=3,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
coordinate_size=6,
|
||||
shape_size=6,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
range_bbox=1000,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.text_seq_length = text_seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.coordinate_size = coordinate_size
|
||||
self.shape_size = shape_size
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.range_bbox = range_bbox
|
||||
|
||||
# LayoutLMv3's sequence length equals the number of text tokens + number of patches + 1 (we add 1 for the CLS token)
|
||||
self.text_seq_length = text_seq_length
|
||||
self.image_seq_length = (image_size // patch_size) ** 2 + 1
|
||||
self.seq_length = self.text_seq_length + self.image_seq_length
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.text_seq_length], self.vocab_size)
|
||||
|
||||
bbox = ids_tensor([self.batch_size, self.text_seq_length, 4], self.range_bbox)
|
||||
# Ensure that bbox is legal
|
||||
for i in range(bbox.shape[0]):
|
||||
for j in range(bbox.shape[1]):
|
||||
if bbox[i, j, 3] < bbox[i, j, 1]:
|
||||
t = bbox[i, j, 3]
|
||||
bbox[i, j, 3] = bbox[i, j, 1]
|
||||
bbox[i, j, 1] = t
|
||||
if bbox[i, j, 2] < bbox[i, j, 0]:
|
||||
t = bbox[i, j, 2]
|
||||
bbox[i, j, 2] = bbox[i, j, 0]
|
||||
bbox[i, j, 0] = t
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.text_seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.text_seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels)
|
||||
|
||||
config = LayoutLMv3Config(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
coordinate_size=self.coordinate_size,
|
||||
shape_size=self.shape_size,
|
||||
input_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
)
|
||||
|
||||
return config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
model = LayoutLMv3Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# text + image
|
||||
result = model(input_ids, pixel_values=pixel_values)
|
||||
result = model(
|
||||
input_ids, bbox=bbox, pixel_values=pixel_values, attention_mask=input_mask, token_type_ids=token_type_ids
|
||||
)
|
||||
result = model(input_ids, bbox=bbox, pixel_values=pixel_values, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, bbox=bbox, pixel_values=pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
# text only
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
# image only
|
||||
result = model(pixel_values=pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.image_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LayoutLMv3ForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
bbox=bbox,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LayoutLMv3ForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
bbox=bbox,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.text_seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, bbox, pixel_values, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
model = LayoutLMv3ForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
bbox=bbox,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
bbox,
|
||||
pixel_values,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"bbox": bbox,
|
||||
"pixel_values": pixel_values,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": input_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class LayoutLMv3ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_mismatched_shapes = False
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
LayoutLMv3Model,
|
||||
LayoutLMv3ForSequenceClassification,
|
||||
LayoutLMv3ForTokenClassification,
|
||||
LayoutLMv3ForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LayoutLMv3ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LayoutLMv3Config, hidden_size=37)
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs_dict = {
|
||||
                k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                if isinstance(v, torch.Tensor) and v.ndim > 1
                else v
                for k, v in inputs_dict.items()
            }
        if return_labels:
            if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
            elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in [
                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in [
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.text_seq_length),
                    dtype=torch.long,
                    device=torch_device,
                )

        return inputs_dict

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in LAYOUTLMV3_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = LayoutLMv3Model.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
class LayoutLMv3ModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_feature_extractor(self):
        return LayoutLMv3FeatureExtractor(apply_ocr=False) if is_vision_available() else None

    @slow
    def test_inference_no_head(self):
        model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base").to(torch_device)

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(torch_device)

        input_ids = torch.tensor([[1, 2]])
        bbox = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).unsqueeze(0)

        # forward pass
        outputs = model(
            input_ids=input_ids.to(torch_device),
            bbox=bbox.to(torch_device),
            pixel_values=pixel_values.to(torch_device),
        )

        # verify the last hidden states
        expected_shape = torch.Size((1, 199, 768))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-0.0529, 0.3618, 0.1632], [-0.1587, -0.1667, -0.0400], [-0.1557, -0.1671, -0.0505]]
        ).to(torch_device)

        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
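The integration test above doubles as a recipe for running the bare LayoutLMv3 backbone. A minimal standalone sketch follows; it assumes the `microsoft/layoutlmv3-base` checkpoint can be downloaded and uses a hypothetical local image path, with the dummy `input_ids`/`bbox` taken from `test_inference_no_head`:

import torch
from PIL import Image

from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3Model

# backbone plus a feature extractor that only resizes/normalizes (no OCR)
model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)

image = Image.open("document.png").convert("RGB")  # hypothetical local document image
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

# two dummy text tokens with their bounding boxes, as in the test above
input_ids = torch.tensor([[1, 2]])
bbox = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]])

with torch.no_grad():
    outputs = model(input_ids=input_ids, bbox=bbox, pixel_values=pixel_values)

# text tokens and visual patch tokens are concatenated along the sequence dimension,
# hence the 199-token sequence checked above for a 224x224 input image
print(outputs.last_hidden_state.shape)  # torch.Size([1, 199, 768])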
446 tests/models/layoutlmv3/test_processor_layoutlmv3.py Normal file
@@ -0,0 +1,446 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest
from typing import List

from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow
from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available


if is_pytesseract_available():
    from PIL import Image

    from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3Processor


@require_pytesseract
@require_tokenizers
class LayoutLMv3ProcessorTest(unittest.TestCase):
    tokenizer_class = LayoutLMv3Tokenizer
    rust_tokenizer_class = LayoutLMv3TokenizerFast

    def setUp(self):
        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        self.tmpdirname = tempfile.mkdtemp()
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        feature_extractor_map = {
            "do_resize": True,
            "size": 224,
            "apply_ocr": True,
        }

        self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(feature_extractor_map) + "\n")

    def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
        return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
        return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
        return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]

    def get_feature_extractor(self, **kwargs):
        return LayoutLMv3FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        feature_extractor = self.get_feature_extractor()
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            processor.save_pretrained(self.tmpdirname)
            processor = LayoutLMv3Processor.from_pretrained(self.tmpdirname)

            self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
            self.assertIsInstance(processor.tokenizer, (LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast))

            self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
            self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        processor = LayoutLMv3Processor(feature_extractor=self.get_feature_extractor(), tokenizer=self.get_tokenizer())
        processor.save_pretrained(self.tmpdirname)

        # slow tokenizer
        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)

        processor = LayoutLMv3Processor.from_pretrained(
            self.tmpdirname, use_fast=False, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, LayoutLMv3Tokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)

        # fast tokenizer
        tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)

        processor = LayoutLMv3Processor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, LayoutLMv3TokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)


# different use cases tests
@require_torch
@require_pytesseract
class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
    @cached_property
    def get_images(self):
        # we verify our implementation on 2 document images from the DocVQA dataset
        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")

        image_1 = Image.open(ds[0]["file"]).convert("RGB")
        image_2 = Image.open(ds[1]["file"]).convert("RGB")

        return image_1, image_2

    @cached_property
    def get_tokenizers(self):
        slow_tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
        fast_tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base", add_visual_labels=False)
        return [slow_tokenizer, fast_tokenizer]

    @slow
    def test_processor_case_1(self):
        # case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True

        feature_extractor = LayoutLMv3FeatureExtractor()
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            input_feat_extract = feature_extractor(images[0], return_tensors="pt")
            input_processor = processor(images[0], return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify image
            self.assertAlmostEqual(
                input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2
            )

            # verify input_ids
            # this was obtained with Tesseract 4.1.1
            # fmt: off
            expected_decoding = "<s> 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer</s>" # noqa: E231
            # fmt: on
            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            input_feat_extract = feature_extractor(images, return_tensors="pt")
            input_processor = processor(images, padding=True, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify images
            self.assertAlmostEqual(
                input_feat_extract["pixel_values"].sum(), input_processor["pixel_values"].sum(), delta=1e-2
            )

            # verify input_ids
            # this was obtained with Tesseract 4.1.1
            # fmt: off
            expected_decoding = "<s> 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC’s Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITC’s value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITC’s brands national assets, adding to India’s competitiveness. It is ITC’s aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>" # noqa: E231
            # fmt: on
            decoding = processor.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)
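Case 1 is the fully automatic path: `LayoutLMv3FeatureExtractor` applies OCR by default, so the processor only needs an image and produces `input_ids`, `attention_mask`, `bbox` and `pixel_values` on its own. A hedged sketch of feeding that encoding to a sequence-classification head follows; the checkpoint name matches the backbone used above, but the 2-label head is illustrative and randomly initialized:

from PIL import Image

from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor

# requires pytesseract, since the feature extractor applies OCR by default
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=2)

image = Image.open("document.png").convert("RGB")  # hypothetical local document image
encoding = processor(image, return_tensors="pt")  # words and boxes come from OCR

outputs = model(**encoding)
predicted_class = outputs.logits.argmax(-1).item()
print(predicted_class)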

    @slow
    def test_processor_case_2(self):
        # case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False

        feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            words = ["hello", "world"]
            boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
            input_processor = processor(images[0], words, boxes=boxes, return_tensors="pt")

            # verify keys
            expected_keys = ["input_ids", "bbox", "attention_mask", "pixel_values"]
            actual_keys = list(input_processor.keys())
            for key in expected_keys:
                self.assertIn(key, actual_keys)

            # verify input_ids
            expected_decoding = "<s> hello world</s>"
            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            words = [["hello", "world"], ["my", "name", "is", "niels"]]
            boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
            input_processor = processor(images, words, boxes=boxes, padding=True, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "<s> hello world</s><pad><pad><pad>"
            decoding = processor.decode(input_processor.input_ids[0].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            expected_bbox = [
                [0, 0, 0, 0],
                [3, 2, 5, 1],
                [6, 7, 4, 2],
                [3, 9, 2, 4],
                [1, 1, 2, 3],
                [1, 1, 2, 3],
                [0, 0, 0, 0],
            ]
            self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)
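The `expected_bbox` above encodes the alignment rule applied when `apply_ocr=False`: every sub-word token inherits the box of the word it came from ("niels" is split into two tokens, hence `[1, 1, 2, 3]` appearing twice), while special tokens such as `<s>`, `</s>` and padding receive the dummy box `[0, 0, 0, 0]`. The same behaviour can be sketched with the fast tokenizer alone, assuming it accepts `boxes=` directly the way the processor forwards them:

from transformers import LayoutLMv3TokenizerFast

tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")

words = ["my", "name", "is", "niels"]
boxes = [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]

encoding = tokenizer(words, boxes=boxes)
print(encoding["bbox"])
# expected, mirroring the test above:
# [[0, 0, 0, 0], [3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [0, 0, 0, 0]]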

    @slow
    def test_processor_case_3(self):
        # case 3: token classification (training), apply_ocr=False

        feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            words = ["weirdly", "world"]
            boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
            word_labels = [1, 2]
            input_processor = processor(images[0], words, boxes=boxes, word_labels=word_labels, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "<s> weirdly world</s>"
            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify labels
            expected_labels = [-100, 1, -100, 2, -100]
            self.assertListEqual(input_processor.labels.squeeze().tolist(), expected_labels)

            # batched
            words = [["hello", "world"], ["my", "name", "is", "niels"]]
            boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
            word_labels = [[1, 2], [6, 3, 10, 2]]
            input_processor = processor(
                images, words, boxes=boxes, word_labels=word_labels, padding=True, return_tensors="pt"
            )

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "<s> my name is niels</s>"
            decoding = processor.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            expected_bbox = [
                [0, 0, 0, 0],
                [3, 2, 5, 1],
                [6, 7, 4, 2],
                [3, 9, 2, 4],
                [1, 1, 2, 3],
                [1, 1, 2, 3],
                [0, 0, 0, 0],
            ]
            self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)

            # verify labels
            expected_labels = [-100, 6, 3, 10, 2, -100, -100]
            self.assertListEqual(input_processor.labels[1].tolist(), expected_labels)
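In the labels above, `-100` is the conventional ignore index: only the first sub-word token of each word keeps its word label, and special tokens, padding and the remaining sub-word tokens are set to `-100`, which the cross-entropy loss skips. A hedged sketch of how such an encoding could drive training of a token classifier; the 11-label head is illustrative and randomly initialized:

from PIL import Image

from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForTokenClassification,
    LayoutLMv3Processor,
    LayoutLMv3TokenizerFast,
)

feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=11)

image = Image.open("form.png").convert("RGB")  # hypothetical annotated document image
words = ["my", "name", "is", "niels"]
boxes = [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]
word_labels = [6, 3, 10, 2]

encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
outputs = model(**encoding)  # positions labelled -100 do not contribute to the loss
print(outputs.loss, outputs.logits.shape)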

    @slow
    def test_processor_case_4(self):
        # case 4: visual question answering (inference), apply_ocr=True

        feature_extractor = LayoutLMv3FeatureExtractor()
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            question = "What's his name?"
            input_processor = processor(images[0], question, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            # this was obtained with Tesseract 4.1.1
            # fmt: off
            expected_decoding = "<s> What's his name?</s></s> 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer</s>" # noqa: E231
            # fmt: on
            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            questions = ["How old is he?", "what's the time"]
            input_processor = processor(
                images, questions, padding="max_length", max_length=20, truncation=True, return_tensors="pt"
            )

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            # this was obtained with Tesseract 4.1.1
            expected_decoding = "<s> what's the time</s></s> 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC</s>"
            decoding = processor.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            # fmt: off
            expected_bbox = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 45, 67, 80], [72, 56, 109, 67], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [289, 59, 365, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [74, 136, 161, 158], [0, 0, 0, 0]] # noqa: E231
            # fmt: on
            self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)

    @slow
    def test_processor_case_5(self):
        # case 5: visual question answering (inference), apply_ocr=False

        feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            question = "What's his name?"
            words = ["hello", "world"]
            boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
            input_processor = processor(images[0], question, words, boxes, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "<s> What's his name?</s></s> hello world</s>"
            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            questions = ["How old is he?", "what's the time"]
            words = [["hello", "world"], ["my", "name", "is", "niels"]]
            boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
            input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "<s> How old is he?</s></s> hello world</s><pad><pad>"
            decoding = processor.decode(input_processor.input_ids[0].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            expected_decoding = "<s> what's the time</s></s> my name is niels</s>"
            decoding = processor.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            expected_bbox = [[6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [0, 0, 0, 0]]
            self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox)
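Cases 4 and 5 show the question-answering layout: the question is encoded as the first sequence and the document words (from OCR or supplied by the caller) as the second. To close the loop, a hedged sketch of extractive QA on top of such an encoding; `LayoutLMv3ForQuestionAnswering` is loaded from the base checkpoint here, so its span head is randomly initialized and the decoded answer is purely illustrative:

import torch
from PIL import Image

from transformers import LayoutLMv3ForQuestionAnswering, LayoutLMv3Processor

processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")  # OCR applied by default
model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")

image = Image.open("document.png").convert("RGB")  # hypothetical local document image
question = "What's his name?"

encoding = processor(image, question, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**encoding)

start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
answer = processor.decode(encoding.input_ids[0, start : end + 1].tolist())
print(answer)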
2345 tests/models/layoutlmv3/test_tokenization_layoutlmv3.py Normal file
File diff suppressed because it is too large
@@ -177,10 +177,11 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
             )

             # verify input_ids
+            # this was obtained with Tesseract 4.1.1
             # fmt: off
             expected_decoding = "<s> 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer</s>" # noqa: E231
             # fmt: on
-            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # batched
@@ -198,10 +199,11 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
             )

             # verify input_ids
+            # this was obtained with Tesseract 4.1.1
             # fmt: off
             expected_decoding = "<s> 7 ITC Limited REPORT AND ACCOUNTS 2013 ITC’s Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITC’s value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITC’s brands national assets, adding to India’s competitiveness. It is ITC’s aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>" # noqa: E231
             # fmt: on
-            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+            decoding = processor.decode(input_processor.input_ids[1].tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

     @slow
@@ -228,7 +230,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):

             # verify input_ids
             expected_decoding = "<s> hello world</s>"
-            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # batched
@@ -243,7 +245,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):

             # verify input_ids
             expected_decoding = "<s> hello world</s><pad><pad>"
-            decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+            decoding = processor.decode(input_processor.input_ids[0].tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # verify bbox
@@ -282,7 +284,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):

             # verify input_ids
             expected_decoding = "<s> weirdly world</s>"
-            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # verify labels
@@ -304,7 +306,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):

             # verify input_ids
             expected_decoding = "<s> my name is niels</s>"
-            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+            decoding = processor.decode(input_processor.input_ids[1].tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # verify bbox
@@ -344,10 +346,11 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
             self.assertListEqual(actual_keys, expected_keys)

             # verify input_ids
+            # this was obtained with Tesseract 4.1.1
             # fmt: off
             expected_decoding = "<s> What's his name?</s></s> 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer</s>" # noqa: E231
             # fmt: on
-            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # batched
@@ -362,8 +365,9 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
             self.assertListEqual(actual_keys, expected_keys)

             # verify input_ids
+            # this was obtained with Tesseract 4.1.1
             expected_decoding = "<s> what's the time</s></s> 7 ITC Limited REPORT AND ACCOUNTS 2013</s>"
-            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+            decoding = processor.decode(input_processor.input_ids[1].tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # verify bbox
@@ -396,7 +400,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):

             # verify input_ids
             expected_decoding = "<s> What's his name?</s></s> hello world</s>"
-            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
+            decoding = processor.decode(input_processor.input_ids.squeeze().tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # batched
@@ -412,11 +416,11 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):

             # verify input_ids
             expected_decoding = "<s> How old is he?</s></s> hello world</s><pad><pad>"
-            decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
+            decoding = processor.decode(input_processor.input_ids[0].tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             expected_decoding = "<s> what's the time</s></s> my name is niels</s>"
-            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
+            decoding = processor.decode(input_processor.input_ids[1].tolist())
             self.assertSequenceEqual(decoding, expected_decoding)

             # verify bbox
@@ -33,6 +33,7 @@ src/transformers/models/gpt2/modeling_gpt2.py
 src/transformers/models/gptj/modeling_gptj.py
 src/transformers/models/hubert/modeling_hubert.py
 src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
+src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
 src/transformers/models/longformer/modeling_longformer.py
 src/transformers/models/longformer/modeling_tf_longformer.py
 src/transformers/models/marian/modeling_marian.py