Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Add LayoutLMv2 + LayoutXLM (#12604)
* First commit
* Make style
* Fix dummy objects
* Add Detectron2 config
* Add LayoutLMv2 pooler
* More improvements, add documentation
* More improvements
* Add model tests
* Add clarification regarding image input
* Improve integration test
* Fix bug
* Fix another bug
* Fix another bug
* Fix another bug
* More improvements
* Make more tests pass
* Make more tests pass
* Improve integration test
* Remove gradient checkpointing and add head masking
* Add integration test
* Add LayoutLMv2ForSequenceClassification to the tests
* Add LayoutLMv2ForQuestionAnswering
* More improvements
* More improvements
* Small improvements
* Fix _LazyModule
* Fix fast tokenizer
* Move sync_batch_norm to a separate method
* Replace dummies by requires_backends
* Move calculation of visual bounding boxes to separate method + update README
* Add models to main init
* First draft
* More improvements
* More improvements
* More improvements
* More improvements
* More improvements
* Remove is_split_into_words
* More improvements
* Simplify tesseract - no use of pandas anymore
* Add LayoutLMv2Processor
* Update is_pytesseract_available
* Fix bugs
* Improve feature extractor
* Fix bug
* Add print statement
* Add truncation of bounding boxes
* Add tests for LayoutLMv2FeatureExtractor and LayoutLMv2Tokenizer
* Improve tokenizer tests
* Make more tokenizer tests pass
* Make more tests pass, add integration tests
* Finish integration tests
* More improvements
* More improvements - update API of the tokenizer
* More improvements
* Remove support for VQA training
* Remove some files
* Improve feature extractor
* Improve documentation and one more tokenizer test
* Make quality and small docs improvements
* Add batched tests for LayoutLMv2Processor, remove fast tokenizer
* Add truncation of labels
* Apply suggestions from code review
* Improve processor tests
* Fix failing tests and add suggestion from code review
* Fix tokenizer test
* Add detectron2 CI job
* Simplify CI job
* Comment out non-detectron2 jobs and specify number of processes
* Add pip install torchvision
* Add durations to see which tests are slow
* Fix tokenizer test and make model tests smaller
* First draft
* Use setattr
* Possible fix
* Proposal with configuration
* First draft of fast tokenizer
* More improvements
* Enable fast tokenizer tests
* Make more tests pass
* Make more tests pass
* More improvements
* Add padding to fast tokenizer
* Make more tests pass
* Make more tests pass
* Make all tests pass for fast tokenizer
* Make fast tokenizer support overflowing boxes and labels
* Add support for overflowing_labels to slow tokenizer
* Add support for fast tokenizer to the processor
* Update processor tests for both slow and fast tokenizers
* Add head models to model mappings
* Make style & quality
* Remove Detectron2 config file
* Add configurable option to label all subwords
* Fix test
* Skip visual segment embeddings in test
* Use ResNet-18 backbone in tests instead of ResNet-101
* Proposal
* Re-enable all jobs on CI
* Fix installation of tesseract
* Fix failing test
* Fix index table
* Add LayoutXLM doc page, first draft of code examples
* Improve documentation a lot
* Update expected boxes for Tesseract 4.0.0 beta
* Use offsets to create labels instead of checking if they start with ##
* Update expected boxes for Tesseract 4.1.1
* Fix conflict
* Make variable names cleaner, add docstring, add link to notebooks
* Revert "Fix conflict"

  This reverts commit a9b46ce9afe47ebfcfe7b45e6a121d49e74ef2c5.

* Revert to make integration test pass
* Apply suggestions from @LysandreJik's review
* Address @patrickvonplaten's comments
* Remove fixtures DocVQA in favor of dataset on the hub

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
parent
439e7abd2d
commit
b6ddb08a66
@@ -798,6 +798,44 @@ jobs:
            - run: pip install requests
            - run: python ./utils/link_tester.py

    run_tests_layoutlmv2:
        working_directory: ~/transformers
        docker:
            - image: circleci/python:3.7
        environment:
            OMP_NUM_THREADS: 1
            TRANSFORMERS_IS_CI: yes
        resource_class: xlarge
        parallelism: 1
        steps:
            - checkout
            - restore_cache:
                  keys:
                      - v0.4-torch-{{ checksum "setup.py" }}
                      - v0.4-{{ checksum "setup.py" }}
            - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
            - run: pip install --upgrade pip
            - run: pip install .[torch,testing,vision]
            - run: pip install torchvision
            - run: python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
            - run: sudo apt install tesseract-ocr
            - run: pip install pytesseract
            - save_cache:
                  key: v0.4-torch-{{ checksum "setup.py" }}
                  paths:
                      - '~/.cache/pip'
            - run: python utils/tests_fetcher.py | tee test_preparation.txt
            - store_artifacts:
                  path: ~/transformers/test_preparation.txt
            - run: |
                  if [ -f test_list.txt ]; then
                      python -m pytest -n 1 tests/*layoutlmv2* --dist=loadfile -s --make-reports=tests_layoutlmv2 --durations=100
                  fi
            - store_artifacts:
                  path: ~/transformers/tests_output.txt
            - store_artifacts:
                  path: ~/transformers/reports

    # TPU JOBS
    run_examples_tpu:
        docker:
@@ -852,6 +890,7 @@ workflows:
            - run_tests_onnxruntime
            - run_tests_hub
            - build_doc
            - run_tests_layoutlmv2
            - deploy_doc: *workflow_filters
    nightly:
        triggers:
@@ -244,6 +244,8 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[Hubert](https://huggingface.co/transformers/model_doc/hubert.html)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
1. **[LayoutLMv2](https://huggingface.co/transformers/model_doc/layoutlmv2.html)** (from Microsoft Research Asia) released with the paper [LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding](https://arxiv.org/abs/2012.14740) by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou.
1. **[LayoutXLM](https://huggingface.co/transformers/model_doc/layoutlmv2.html)** (from Microsoft Research Asia) released with the paper [LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding](https://arxiv.org/abs/2104.08836) by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[Longformer](https://huggingface.co/transformers/model_doc/longformer.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LUKE](https://huggingface.co/transformers/model_doc/luke.html)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
@@ -202,99 +202,106 @@ Supported models
34. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
    of Text and Layout for Document Image Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li,
    Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
35. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
35. :doc:`LayoutLMv2 <model_doc/layoutlmv2>` (from Microsoft Research Asia) released with the paper `LayoutLMv2:
    Multi-modal Pre-training for Visually-Rich Document Understanding <https://arxiv.org/abs/2012.14740>`__ by Yang Xu,
    Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min
    Zhang, Lidong Zhou.
36. :doc:`LayoutXLM <model_doc/layoutlmv2>` (from Microsoft Research Asia) released with the paper `LayoutXLM:
    Multimodal Pre-training for Multilingual Visually-rich Document Understanding <https://arxiv.org/abs/2104.08836>`__
    by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
37. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
    <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
36. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
38. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
    Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
37. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity
39. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity
    Representations with Entity-aware Self-attention <https://arxiv.org/abs/2010.01057>`__ by Ikuya Yamada, Akari Asai,
    Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
38. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
40. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
    Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__
    by Hao Tan and Mohit Bansal.
39. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
41. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
    Machine Translation <https://arxiv.org/abs/2010.11125>`__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi
    Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman
    Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
40. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
42. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
    Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft
    Translator Team.
41. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
43. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
    Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
    Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
42. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
44. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
    Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
    Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
43. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training
45. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training
    Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad
    Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
44. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training
46. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training
    Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad
    Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
45. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
47. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
    Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
    Jianfeng Lu, Tie-Yan Liu.
46. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
48. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
    text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
    Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
47. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
49. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
    Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__> by Jingqing Zhang, Yao Zhao,
    Mohammad Saleh and Peter J. Liu.
48. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
50. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
    Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
    Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
49. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
51. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
    Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
50. :doc:`RemBERT <model_doc/rembert>` (from Google Research) released with the paper `Rethinking embedding coupling in
52. :doc:`RemBERT <model_doc/rembert>` (from Google Research) released with the paper `Rethinking embedding coupling in
    pre-trained language models <https://arxiv.org/pdf/2010.12821.pdf>`__ by Hyung Won Chung, Thibault Févry, Henry
    Tsai, M. Johnson, Sebastian Ruder.
51. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
53. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
    Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
    Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
52. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper a `RoFormer:
54. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper a `RoFormer:
    Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and
    Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
53. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
55. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
    `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
    Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
54. `Splinter <https://huggingface.co/transformers/master/model_doc/splinter.html>`__ (from Tel Aviv University),
56. `Splinter <https://huggingface.co/transformers/master/model_doc/splinter.html>`__ (from Tel Aviv University),
    released together with the paper `Few-Shot Question Answering by Pretraining Span Selection
    <https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
55. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
57. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
    about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
    Krishna, and Kurt W. Keutzer.
56. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
58. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
    Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
    Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
57. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
59. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
    Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
    Francesco Piccinno and Julian Martin Eisenschlos.
58. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
60. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
    Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
    Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
59. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
61. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
    Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
    Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
    Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
60. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
62. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
    Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
    Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
61. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
63. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
    Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
    Zhou, Abdelrahman Mohamed, Michael Auli.
62. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
64. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
    Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
63. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
65. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
    Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
    Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
64. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
66. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
    Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
    Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
    Zettlemoyer and Veselin Stoyanov.
65. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
67. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
    Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
    Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
66. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
68. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
    Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
    Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.

@@ -372,6 +379,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLM | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LayoutLMv2 | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |

@@ -550,6 +559,8 @@ Flax), PyTorch, and/or TensorFlow.
    model_doc/herbert
    model_doc/ibert
    model_doc/layoutlm
    model_doc/layoutlmv2
    model_doc/layoutxlm
    model_doc/led
    model_doc/longformer
    model_doc/luke
docs/source/model_doc/layoutlmv2.rst (new file, 314 lines)
@@ -0,0 +1,314 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

LayoutLMV2
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The LayoutLMV2 model was proposed in `LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding
<https://arxiv.org/abs/2012.14740>`__ by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu,
Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. LayoutLMV2 improves `LayoutLM
<https://huggingface.co/transformers/model_doc/layoutlm.html>`__ to obtain state-of-the-art results across several
document image understanding benchmarks:

- information extraction from scanned documents: the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__ dataset (a
  collection of 199 annotated forms comprising more than 30,000 words), the `CORD <https://github.com/clovaai/cord>`__
  dataset (a collection of 800 receipts for training, 100 for validation and 100 for testing), the `SROIE
  <https://rrc.cvc.uab.es/?ch=13>`__ dataset (a collection of 626 receipts for training and 347 receipts for testing)
  and the `Kleister-NDA <https://github.com/applicaai/kleister-nda>`__ dataset (a collection of non-disclosure
  agreements from the EDGAR database, including 254 documents for training, 83 documents for validation, and 203
  documents for testing).
- document image classification: the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset (a collection of
  400,000 images belonging to one of 16 classes).
- document visual question answering: the `DocVQA <https://arxiv.org/abs/2007.00398>`__ dataset (a collection of 50,000
  questions defined on 12,000+ document images).

The abstract from the paper is the following:

*Pre-training of text and layout has proved effective in a variety of visually-rich document understanding tasks due to
its effective model architecture and the advantage of large-scale unlabeled scanned/digital-born documents. In this
paper, we present LayoutLMv2 by pre-training text, layout and image in a multi-modal framework, where new model
architectures and pre-training tasks are leveraged. Specifically, LayoutLMv2 not only uses the existing masked
visual-language modeling task but also the new text-image alignment and text-image matching tasks in the pre-training
stage, where cross-modality interaction is better learned. Meanwhile, it also integrates a spatial-aware self-attention
mechanism into the Transformer architecture, so that the model can fully understand the relative positional
relationship among different text blocks. Experiment results show that LayoutLMv2 outperforms strong baselines and
achieves new state-of-the-art results on a wide variety of downstream visually-rich document understanding tasks,
including FUNSD (0.7895 -> 0.8420), CORD (0.9493 -> 0.9601), SROIE (0.9524 -> 0.9781), Kleister-NDA (0.834 -> 0.852),
RVL-CDIP (0.9443 -> 0.9564), and DocVQA (0.7295 -> 0.8672). The pre-trained LayoutLMv2 model is publicly available at
this https URL.*

Tips:

- The main difference between LayoutLMv1 and LayoutLMv2 is that the latter incorporates visual embeddings during
  pre-training (while LayoutLMv1 only adds visual embeddings during fine-tuning).
- LayoutLMv2 adds both a relative 1D attention bias as well as a spatial 2D attention bias to the attention scores in
  the self-attention layers. Details can be found on page 5 of the `paper <https://arxiv.org/abs/2012.14740>`__.
- Demo notebooks on how to use the LayoutLMv2 model on RVL-CDIP, FUNSD, DocVQA, CORD can be found `here
  <https://github.com/NielsRogge/Transformers-Tutorials>`__.
- LayoutLMv2 uses Facebook AI's `Detectron2 <https://github.com/facebookresearch/detectron2/>`__ package for its visual
  backbone. See `this link <https://detectron2.readthedocs.io/en/latest/tutorials/install.html>`__ for installation
  instructions.
- In addition to :obj:`input_ids`, :meth:`~transformers.LayoutLMv2Model.forward` expects 2 additional inputs, namely
  :obj:`image` and :obj:`bbox`. The :obj:`image` input corresponds to the original document image in which the text
  tokens occur. The model expects each document image to be of size 224x224. This means that if you have a batch of
  document images, :obj:`image` should be a tensor of shape (batch_size, 3, 224, 224). This can be either a
  :obj:`torch.Tensor` or a :obj:`Detectron2.structures.ImageList`. You don't need to normalize the channels, as this is
  done by the model. Note that the visual backbone expects BGR channels instead of RGB, as all models in Detectron2 are
  pre-trained using the BGR format. The :obj:`bbox` input contains the bounding boxes (i.e. 2D positions) of the input
  text tokens. This is identical to :class:`~transformers.LayoutLMModel`. These can be obtained using an external OCR
  engine such as Google's `Tesseract <https://github.com/tesseract-ocr/tesseract>`__ (there's a `Python wrapper
  <https://pypi.org/project/pytesseract/>`__ available). Each bounding box should be in (x0, y0, x1, y1) format, where
  (x0, y0) corresponds to the position of the upper left corner of the bounding box, and (x1, y1) represents the
  position of the lower right corner. Note that one first needs to normalize the bounding boxes to be on a 0-1000
  scale. To normalize, you can use the following function:

.. code-block::

    def normalize_bbox(bbox, width, height):
        return [
            int(1000 * (bbox[0] / width)),
            int(1000 * (bbox[1] / height)),
            int(1000 * (bbox[2] / width)),
            int(1000 * (bbox[3] / height)),
        ]

Here, :obj:`width` and :obj:`height` correspond to the width and height of the original document in which the token
occurs (before resizing the image). Those can be obtained using the Python Imaging Library (PIL), for example, as
follows:

.. code-block::

    from PIL import Image

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.")

    width, height = image.size

However, this model includes a brand new :class:`~transformers.LayoutLMv2Processor` which can be used to directly
prepare data for the model (including applying OCR under the hood). More information can be found in the "Usage"
section below.

- Internally, :class:`~transformers.LayoutLMv2Model` will send the :obj:`image` input through its visual backbone to
  obtain a lower-resolution feature map, whose shape is equal to the :obj:`image_feature_pool_shape` attribute of
  :class:`~transformers.LayoutLMv2Config`. This feature map is then flattened to obtain a sequence of image tokens. As
  the size of the feature map is 7x7 by default, one obtains 49 image tokens. These are then concatenated with the text
  tokens, and sent through the Transformer encoder. This means that the last hidden states of the model will have a
  length of 512 + 49 = 561, if you pad the text tokens up to the max length. More generally, the last hidden states
  will have a length of :obj:`seq_length` + :obj:`config.image_feature_pool_shape[0]` *
  :obj:`config.image_feature_pool_shape[1]` (see the short sketch after these tips).
- When calling :meth:`~transformers.LayoutLMv2Model.from_pretrained`, a warning will be printed with a long list of
  parameter names that are not initialized. This is not a problem, as these parameters are batch normalization
  statistics, which will get proper values when fine-tuning on a custom dataset.
- If you want to train the model in a distributed environment, make sure to call :meth:`synchronize_batch_norm` on the
  model in order to properly synchronize the batch normalization layers of the visual backbone.
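Below is a minimal sketch of the two points above: preparing the 224x224 BGR :obj:`image` tensor by hand and computing
the resulting sequence length. The manual resizing and channel flipping shown here are illustrative assumptions; in
practice :class:`~transformers.LayoutLMv2Processor` (or :class:`~transformers.LayoutLMv2FeatureExtractor`) takes care
of this for you.

.. code-block::

    import torch
    from PIL import Image
    from transformers import LayoutLMv2Config

    # open the original document image and keep its size for bounding box normalization
    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    width, height = image.size

    # resize to 224x224 and reorder channels RGB -> BGR, since the Detectron2 backbone expects BGR input
    resized = image.resize((224, 224))
    pixel_values = torch.tensor(list(resized.getdata()), dtype=torch.float32).view(224, 224, 3)
    image_tensor = pixel_values.permute(2, 0, 1)[[2, 1, 0], :, :].unsqueeze(0)  # shape (1, 3, 224, 224)

    # the visual backbone pools its feature map to config.image_feature_pool_shape (7x7 by default),
    # so 7 * 7 = 49 image tokens are appended to the text tokens
    config = LayoutLMv2Config()
    num_image_tokens = config.image_feature_pool_shape[0] * config.image_feature_pool_shape[1]
    print(512 + num_image_tokens)  # 561 hidden states when the text is padded to 512 tokens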
In addition, there's LayoutXLM, which is a multilingual version of LayoutLMv2. More information can be found on
:doc:`LayoutXLM's documentation page <layoutxlm>`.

Usage: LayoutLMv2Processor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The easiest way to prepare data for the model is to use :class:`~transformers.LayoutLMv2Processor`, which internally
combines a feature extractor (:class:`~transformers.LayoutLMv2FeatureExtractor`) and a tokenizer
(:class:`~transformers.LayoutLMv2Tokenizer` or :class:`~transformers.LayoutLMv2TokenizerFast`). The feature extractor
handles the image modality, while the tokenizer handles the text modality. A processor combines both, which is ideal
for a multi-modal model like LayoutLMv2. Note that you can still use both separately, if you only want to handle one
modality.

.. code-block::

    from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2TokenizerFast, LayoutLMv2Processor

    feature_extractor = LayoutLMv2FeatureExtractor()  # apply_ocr is set to True by default
    tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
    processor = LayoutLMv2Processor(feature_extractor, tokenizer)

In short, one can provide a document image (and possibly additional data) to :class:`~transformers.LayoutLMv2Processor`,
and it will create the inputs expected by the model. Internally, the processor first uses
:class:`~transformers.LayoutLMv2FeatureExtractor` to apply OCR on the image to get a list of words and normalized
bounding boxes, as well as to resize the image to a given size in order to get the :obj:`image` input. The words and
normalized bounding boxes are then provided to :class:`~transformers.LayoutLMv2Tokenizer` or
:class:`~transformers.LayoutLMv2TokenizerFast`, which converts them to token-level :obj:`input_ids`,
:obj:`attention_mask`, :obj:`token_type_ids` and :obj:`bbox`. Optionally, one can provide word labels to the processor,
which are turned into token-level :obj:`labels`.

:class:`~transformers.LayoutLMv2Processor` uses `PyTesseract <https://pypi.org/project/pytesseract/>`__, a Python
wrapper around Google's Tesseract OCR engine, under the hood. Note that you can still use your own OCR engine of
choice, and provide the words and normalized boxes yourself. This requires initializing
:class:`~transformers.LayoutLMv2FeatureExtractor` with :obj:`apply_ocr` set to :obj:`False`.

In total, there are 5 use cases that are supported by the processor. Below, we list them all. Note that each of these
use cases works for both batched and non-batched inputs (we illustrate them for non-batched inputs).

**Use case 1: document image classification (training, inference) + token classification (inference), apply_ocr =
True**

This is the simplest case, in which the processor (actually the feature extractor) will perform OCR on the image to get
the words and normalized bounding boxes.

.. code-block::

    from transformers import LayoutLMv2Processor
    from PIL import Image

    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    encoding = processor(image, return_tensors="pt")  # you can also add all tokenizer parameters here such as padding, truncation
    print(encoding.keys())
    # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'])
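Building on use case 1, here is a minimal, hypothetical sketch of feeding such an encoding into
:class:`~transformers.LayoutLMv2ForSequenceClassification` for document image classification; the number of labels and
the label id are placeholders for your own dataset (e.g. the 16 classes of RVL-CDIP).

.. code-block::

    import torch
    from transformers import LayoutLMv2ForSequenceClassification

    # requires detectron2 to be installed for the visual backbone
    model = LayoutLMv2ForSequenceClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", num_labels=16)

    labels = torch.tensor([1])  # placeholder class index for this document
    outputs = model(**encoding, labels=labels)
    loss, logits = outputs.loss, outputs.logits
    predicted_class = logits.argmax(-1).item()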
**Use case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False**

If you want to perform OCR yourself, you can initialize the feature extractor with :obj:`apply_ocr` set to
:obj:`False`. In that case, you should provide the words and corresponding (normalized) bounding boxes yourself to
the processor.

.. code-block::

    from transformers import LayoutLMv2Processor
    from PIL import Image

    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    words = ["hello", "world"]
    boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # make sure to normalize your bounding boxes
    encoding = processor(image, words, boxes=boxes, return_tensors="pt")
    print(encoding.keys())
    # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'])
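If you do run OCR yourself, a sketch of producing such words and normalized boxes with PyTesseract (the same engine the
feature extractor would otherwise call for you) could look as follows; skipping empty OCR results is an illustrative
choice, not a requirement.

.. code-block::

    import pytesseract
    from PIL import Image

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    width, height = image.size

    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    words, boxes = [], []
    for word, left, top, w, h in zip(data["text"], data["left"], data["top"], data["width"], data["height"]):
        if word.strip() == "":
            continue  # skip empty OCR results
        words.append(word)
        # normalize to the 0-1000 scale expected by the model (see normalize_bbox above)
        boxes.append([
            int(1000 * left / width),
            int(1000 * top / height),
            int(1000 * (left + w) / width),
            int(1000 * (top + h) / height),
        ])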
**Use case 3: token classification (training), apply_ocr=False**

For token classification tasks (such as FUNSD, CORD, SROIE, Kleister-NDA), one can also provide the corresponding word
labels in order to train a model. The processor will then convert these into token-level :obj:`labels`. By default, it
will only label the first wordpiece of a word, and label the remaining wordpieces with -100, which is the
:obj:`ignore_index` of PyTorch's CrossEntropyLoss. In case you want all wordpieces of a word to be labeled, you can
initialize the tokenizer with :obj:`only_label_first_subword` set to :obj:`False`.

.. code-block::

    from transformers import LayoutLMv2Processor
    from PIL import Image

    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    words = ["hello", "world"]
    boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # make sure to normalize your bounding boxes
    word_labels = [1, 2]
    encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
    print(encoding.keys())
    # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'labels', 'image'])
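A minimal sketch of using this encoding to fine-tune :class:`~transformers.LayoutLMv2ForTokenClassification`; the
number of labels is a placeholder for the label set of your own dataset.

.. code-block::

    from transformers import LayoutLMv2ForTokenClassification

    # requires detectron2 to be installed for the visual backbone
    model = LayoutLMv2ForTokenClassification.from_pretrained("microsoft/layoutlmv2-base-uncased", num_labels=7)

    outputs = model(**encoding)  # the encoding already contains the token-level labels
    loss = outputs.loss          # cross-entropy loss that ignores positions labeled -100
    loss.backward()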
**Use case 4: visual question answering (inference), apply_ocr=True**

For visual question answering tasks (such as DocVQA), you can provide a question to the processor. By default, the
processor will apply OCR on the image, and create [CLS] question tokens [SEP] word tokens [SEP].

.. code-block::

    from transformers import LayoutLMv2Processor
    from PIL import Image

    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    question = "What's his name?"
    encoding = processor(image, question, return_tensors="pt")
    print(encoding.keys())
    # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'])
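A minimal sketch of running :class:`~transformers.LayoutLMv2ForQuestionAnswering` on this encoding and decoding the
predicted span; the simple argmax over start/end logits is just one reasonable decoding choice. Note that with the base
(not fine-tuned) checkpoint the prediction will not be meaningful; in practice you would load a checkpoint fine-tuned
on DocVQA.

.. code-block::

    import torch
    from transformers import LayoutLMv2ForQuestionAnswering

    # requires detectron2 to be installed for the visual backbone
    model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")

    with torch.no_grad():
        outputs = model(**encoding)

    start_index = outputs.start_logits.argmax(-1).item()
    end_index = outputs.end_logits.argmax(-1).item()
    answer_tokens = encoding["input_ids"][0, start_index : end_index + 1]
    print(processor.tokenizer.decode(answer_tokens))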
**Use case 5: visual question answering (inference), apply_ocr=False**

For visual question answering tasks (such as DocVQA), you can provide a question to the processor. If you want to
perform OCR yourself, you can provide your own words and (normalized) bounding boxes to the processor.

.. code-block::

    from transformers import LayoutLMv2Processor
    from PIL import Image

    processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

    image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
    question = "What's his name?"
    words = ["hello", "world"]
    boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # make sure to normalize your bounding boxes
    encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
    print(encoding.keys())
    # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'image'])

LayoutLMv2Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2Config
    :members:


LayoutLMv2FeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2FeatureExtractor
    :members: __call__


LayoutLMv2Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2Tokenizer
    :members: __call__, save_vocabulary


LayoutLMv2TokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2TokenizerFast
    :members: __call__


LayoutLMv2Processor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2Processor
    :members: __call__


LayoutLMv2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2Model
    :members: forward


LayoutLMv2ForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2ForSequenceClassification
    :members:


LayoutLMv2ForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2ForTokenClassification
    :members:


LayoutLMv2ForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LayoutLMv2ForQuestionAnswering
    :members:
docs/source/model_doc/layoutxlm.rst (new file, 47 lines)
@@ -0,0 +1,47 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

LayoutXLM
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

LayoutXLM was proposed in `LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding
<https://arxiv.org/abs/2104.08836>`__ by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha
Zhang, Furu Wei. It's a multilingual extension of the `LayoutLMv2 model <https://arxiv.org/abs/2012.14740>`__ trained
on 53 languages.

The abstract from the paper is the following:

*Multimodal pre-training with text, layout, and image has achieved SOTA performance for visually-rich document
understanding tasks recently, which demonstrates the great potential for joint learning across different modalities. In
this paper, we present LayoutXLM, a multimodal pre-trained model for multilingual document understanding, which aims to
bridge the language barriers for visually-rich document understanding. To accurately evaluate LayoutXLM, we also
introduce a multilingual form understanding benchmark dataset named XFUN, which includes form understanding samples in
7 languages (Chinese, Japanese, Spanish, French, Italian, German, Portuguese), and key-value pairs are manually labeled
for each language. Experiment results show that the LayoutXLM model has significantly outperformed the existing SOTA
cross-lingual pre-trained models on the XFUN dataset.*

One can directly plug in the weights of LayoutXLM into a LayoutLMv2 model, like so:

.. code-block::

    from transformers import LayoutLMv2Model

    model = LayoutLMv2Model.from_pretrained('microsoft/layoutxlm-base')

As LayoutXLM's architecture is equivalent to that of LayoutLMv2, one can refer to :doc:`LayoutLMv2's documentation page
<layoutlmv2>` for all tips, code examples and notebooks.
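Likewise, since the architectures match, the LayoutLMv2 task heads accept LayoutXLM weights as well. A small,
hypothetical sketch (the number of labels is a placeholder for your own dataset):

.. code-block::

    from transformers import LayoutLMv2ForTokenClassification

    # token classification head (e.g. for multilingual form understanding) on top of the LayoutXLM backbone
    model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutxlm-base', num_labels=7)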
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The original code can be found `here
<https://github.com/microsoft/unilm>`__.
@@ -206,6 +206,13 @@ _import_structure = {
    "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
    "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
    "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
    "models.layoutlmv2": [
        "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "LayoutLMv2Config",
        "LayoutLMv2FeatureExtractor",
        "LayoutLMv2Processor",
        "LayoutLMv2Tokenizer",
    ],
    "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
    "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
    "models.luke": ["LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP", "LukeConfig", "LukeTokenizer"],
@@ -356,6 +363,7 @@ if is_tokenizers_available():
    _import_structure["models.gpt2"].append("GPT2TokenizerFast")
    _import_structure["models.herbert"].append("HerbertTokenizerFast")
    _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast")
    _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast")
    _import_structure["models.led"].append("LEDTokenizerFast")
    _import_structure["models.longformer"].append("LongformerTokenizerFast")
    _import_structure["models.lxmert"].append("LxmertTokenizerFast")
@@ -396,7 +404,6 @@ else:
# Speech-specific objects
if is_speech_available():
    _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")

else:
    from .utils import dummy_speech_objects
@@ -421,6 +428,8 @@ if is_vision_available():
    _import_structure["models.clip"].append("CLIPProcessor")
    _import_structure["models.deit"].append("DeiTFeatureExtractor")
    _import_structure["models.detr"].append("DetrFeatureExtractor")
    _import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor")
    _import_structure["models.layoutlmv2"].append("LayoutLMv2Processor")
    _import_structure["models.vit"].append("ViTFeatureExtractor")
else:
    from .utils import dummy_vision_objects
@@ -845,6 +854,16 @@ if is_torch_available():
            "LayoutLMPreTrainedModel",
        ]
    )
    _import_structure["models.layoutlmv2"].extend(
        [
            "LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST",
            "LayoutLMv2ForQuestionAnswering",
            "LayoutLMv2ForSequenceClassification",
            "LayoutLMv2ForTokenClassification",
            "LayoutLMv2Model",
            "LayoutLMv2PreTrainedModel",
        ]
    )
    _import_structure["models.led"].extend(
        [
            "LED_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1905,6 +1924,13 @@ if TYPE_CHECKING:
    from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
    from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
    from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
    from .models.layoutlmv2 import (
        LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP,
        LayoutLMv2Config,
        LayoutLMv2FeatureExtractor,
        LayoutLMv2Processor,
        LayoutLMv2Tokenizer,
    )
    from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
    from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
    from .models.luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig, LukeTokenizer
@@ -2045,6 +2071,7 @@ if TYPE_CHECKING:
        from .models.gpt2 import GPT2TokenizerFast
        from .models.herbert import HerbertTokenizerFast
        from .models.layoutlm import LayoutLMTokenizerFast
        from .models.layoutlmv2 import LayoutLMv2TokenizerFast
        from .models.led import LEDTokenizerFast
        from .models.longformer import LongformerTokenizerFast
        from .models.lxmert import LxmertTokenizerFast
@@ -2077,7 +2104,6 @@ if TYPE_CHECKING:

    if is_speech_available():
        from .models.speech_to_text import Speech2TextFeatureExtractor

    else:
        from .utils.dummy_speech_objects import *

@@ -2092,6 +2118,7 @@ if TYPE_CHECKING:
        from .models.clip import CLIPFeatureExtractor, CLIPProcessor
        from .models.deit import DeiTFeatureExtractor
        from .models.detr import DetrFeatureExtractor
        from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor
        from .models.vit import ViTFeatureExtractor
    else:
        from .utils.dummy_vision_objects import *
@@ -2448,6 +2475,14 @@ if TYPE_CHECKING:
            LayoutLMModel,
            LayoutLMPreTrainedModel,
        )
        from .models.layoutlmv2 import (
            LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
            LayoutLMv2ForQuestionAnswering,
            LayoutLMv2ForSequenceClassification,
            LayoutLMv2ForTokenClassification,
            LayoutLMv2Model,
            LayoutLMv2PreTrainedModel,
        )
        from .models.led import (
            LED_PRETRAINED_MODEL_ARCHIVE_LIST,
            LEDForConditionalGeneration,
@@ -842,6 +842,45 @@ class CLIPConverter(Converter):
        return tokenizer


class LayoutLMv2Converter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = True
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
@@ -861,6 +900,7 @@ SLOW_TO_FAST_CONVERTERS = {
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
@@ -137,6 +137,14 @@ except importlib_metadata.PackageNotFoundError:
    _datasets_available = False


_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
    _detectron2_version = importlib_metadata.version("detectron2")
    logger.debug(f"Successfully imported detectron2 version {_detectron2_version}")
except importlib_metadata.PackageNotFoundError:
    _detectron2_available = False


_faiss_available = importlib.util.find_spec("faiss") is not None
try:
    _faiss_version = importlib_metadata.version("faiss")

@@ -352,6 +360,10 @@ def is_datasets_available():
    return _datasets_available


def is_detectron2_available():
    return _detectron2_available


def is_rjieba_available():
    return importlib.util.find_spec("rjieba") is not None

@@ -400,6 +412,10 @@ def is_vision_available():
    return importlib.util.find_spec("PIL") is not None


def is_pytesseract_available():
    return importlib.util.find_spec("pytesseract") is not None


def is_in_notebook():
    try:
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py

@@ -576,6 +592,14 @@ installation page: https://www.tensorflow.org/install and follow the ones that m
"""


# docstyle-ignore
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment.
"""


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the

@@ -623,13 +647,22 @@ VISION_IMPORT_ERROR = """
"""


# docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`
"""


BACKENDS_MAPPING = OrderedDict(
    [
        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
        ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
        ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
        ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)),
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),

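These helpers make detectron2 and pytesseract soft dependencies: nothing fails at import time, and a clear error is raised only when a LayoutLMv2 component that needs the backend is actually used. A minimal sketch of checking them up front (function names as added above):

# Sketch: checking the optional LayoutLMv2 backends before using them.
from transformers.file_utils import is_detectron2_available, is_pytesseract_available

if not is_detectron2_available():
    print("detectron2 missing: the LayoutLMv2 visual backbone cannot be built.")
if not is_pytesseract_available():
    print("pytesseract missing: LayoutLMv2FeatureExtractor(apply_ocr=True) will not work.")
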
@@ -54,6 +54,7 @@ from . import (
    hubert,
    ibert,
    layoutlm,
    layoutlmv2,
    led,
    longformer,
    luke,

@@ -26,6 +26,7 @@ from ...file_utils import CONFIG_NAME
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
        ("visual_bert", "VisualBertConfig"),
@@ -95,6 +96,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("visual_bert", "VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -158,6 +160,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        # Add full (and cased) model names here
        ("beit", "BeiT"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
        ("visual_bert", "VisualBert"),
        ("canine", "Canine"),
        ("roformer", "RoFormer"),

@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("layoutlmv2", "LayoutLMv2Model"),
        ("beit", "BeitModel"),
        ("rembert", "RemBertModel"),
        ("visual_bert", "VisualBertModel"),
@@ -285,6 +286,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
@@ -327,6 +329,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
@@ -371,6 +374,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),

@@ -120,6 +120,7 @@ else:
        ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
        ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
        (
            "dpr",
            (

src/transformers/models/layoutlmv2/__init__.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available, is_vision_available


_import_structure = {
    "configuration_layoutlmv2": ["LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMv2Config"],
    "tokenization_layoutlmv2": ["LayoutLMv2Tokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_layoutlmv2_fast"] = ["LayoutLMv2TokenizerFast"]

if is_vision_available():
    _import_structure["feature_extraction_layoutlmv2"] = ["LayoutLMv2FeatureExtractor"]
    _import_structure["processing_layoutlmv2"] = ["LayoutLMv2Processor"]

if is_torch_available():
    _import_structure["modeling_layoutlmv2"] = [
        "LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "LayoutLMv2ForQuestionAnswering",
        "LayoutLMv2ForSequenceClassification",
        "LayoutLMv2ForTokenClassification",
        "LayoutLMv2Layer",
        "LayoutLMv2Model",
        "LayoutLMv2PreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_layoutlmv2 import LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMv2Config
    from .tokenization_layoutlmv2 import LayoutLMv2Tokenizer

    if is_tokenizers_available():
        from .tokenization_layoutlmv2_fast import LayoutLMv2TokenizerFast

    if is_vision_available():
        from .feature_extraction_layoutlmv2 import LayoutLMv2FeatureExtractor
        from .processing_layoutlmv2 import LayoutLMv2Processor

    if is_torch_available():
        from .modeling_layoutlmv2 import (
            LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST,
            LayoutLMv2ForQuestionAnswering,
            LayoutLMv2ForSequenceClassification,
            LayoutLMv2ForTokenClassification,
            LayoutLMv2Layer,
            LayoutLMv2Model,
            LayoutLMv2PreTrainedModel,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

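With the `_LazyModule` indirection, the submodules listed in `_import_structure` are only imported when one of their names is first accessed. A minimal sketch of that behaviour (attribute names as declared above):

# Sketch: names in _import_structure resolve lazily on first attribute access.
from transformers.models import layoutlmv2

print(layoutlmv2.LayoutLMv2Config.model_type)  # "layoutlmv2"; imports configuration_layoutlmv2 on demand
# layoutlmv2.LayoutLMv2Model only resolves if torch is installed (see is_torch_available above)
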
src/transformers/models/layoutlmv2/configuration_layoutlmv2.py (new file, 224 lines)
@@ -0,0 +1,224 @@
# coding=utf-8
|
||||
# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" LayoutLMv2 model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import is_detectron2_available
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/config.json",
|
||||
"layoutlmv2-large-uncased": "https://huggingface.co/microsoft/layoutlmv2-large-uncased/resolve/main/config.json",
|
||||
# See all LayoutLMv2 models at https://huggingface.co/models?filter=layoutlmv2
|
||||
}
|
||||
|
||||
# soft dependency
|
||||
if is_detectron2_available():
|
||||
import detectron2
|
||||
|
||||
|
||||
class LayoutLMv2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.LayoutLMv2Model`. It is used
|
||||
to instantiate a LayoutLMv2 model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the LayoutLMv2
|
||||
`microsoft/layoutlmv2-base-uncased <https://huggingface.co/microsoft/layoutlmv2-base-uncased>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 30522):
|
||||
Vocabulary size of the LayoutLMv2 model. Defines the number of different tokens that can be represented by
|
||||
the :obj:`input_ids` passed when calling :class:`~transformers.LayoutLMv2Model`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
|
||||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.LayoutLMv2Model`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
|
||||
The maximum value that the 2D position embedding might ever be used with. Typically set this to something
|
||||
large just in case (e.g., 1024).
|
||||
max_rel_pos (:obj:`int`, `optional`, defaults to 128):
|
||||
The maximum number of relative positions to be used in the self-attention mechanism.
|
||||
rel_pos_bins (:obj:`int`, `optional`, defaults to 32):
|
||||
The number of relative position bins to be used in the self-attention mechanism.
|
||||
fast_qkv (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use a single matrix for the queries, keys, values in the self-attention layers.
|
||||
max_rel_2d_pos (:obj:`int`, `optional`, defaults to 256):
|
||||
The maximum number of relative 2D positions in the self-attention mechanism.
|
||||
rel_2d_pos_bins (:obj:`int`, `optional`, defaults to 64):
|
||||
The number of 2D relative position bins in the self-attention mechanism.
|
||||
image_feature_pool_shape (:obj:`List[int]`, `optional`, defaults to [7, 7, 256]):
|
||||
The shape of the average-pooled feature map.
|
||||
coordinate_size (:obj:`int`, `optional`, defaults to 128):
|
||||
Dimension of the coordinate embeddings.
|
||||
shape_size (:obj:`int`, `optional`, defaults to 128):
|
||||
Dimension of the width and height embeddings.
|
||||
has_relative_attention_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use a relative attention bias in the self-attention mechanism.
|
||||
has_spatial_attention_bias (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use a spatial attention bias in the self-attention mechanism.
|
||||
has_visual_segment_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to add visual segment embeddings.
|
||||
detectron2_config_args (:obj:`dict`, `optional`):
|
||||
Dictionary containing the configuration arguments of the Detectron2 visual backbone. Refer to `this file
|
||||
<https://github.com/microsoft/unilm/blob/master/layoutlmft/layoutlmft/models/layoutlmv2/detectron2_config.py>`__
|
||||
for details regarding default values.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import LayoutLMv2Model, LayoutLMv2Config
|
||||
|
||||
>>> # Initializing a LayoutLMv2 microsoft/layoutlmv2-base-uncased style configuration
|
||||
>>> configuration = LayoutLMv2Config()
|
||||
|
||||
>>> # Initializing a model from the microsoft/layoutlmv2-base-uncased style configuration
|
||||
>>> model = LayoutLMv2Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "layoutlmv2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
max_2d_position_embeddings=1024,
|
||||
max_rel_pos=128,
|
||||
rel_pos_bins=32,
|
||||
fast_qkv=True,
|
||||
max_rel_2d_pos=256,
|
||||
rel_2d_pos_bins=64,
|
||||
convert_sync_batchnorm=True,
|
||||
image_feature_pool_shape=[7, 7, 256],
|
||||
coordinate_size=128,
|
||||
shape_size=128,
|
||||
has_relative_attention_bias=True,
|
||||
has_spatial_attention_bias=True,
|
||||
has_visual_segment_embedding=False,
|
||||
detectron2_config_args=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=hidden_act,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
type_vocab_size=type_vocab_size,
|
||||
initializer_range=initializer_range,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
pad_token_id=pad_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.max_2d_position_embeddings = max_2d_position_embeddings
|
||||
self.max_rel_pos = max_rel_pos
|
||||
self.rel_pos_bins = rel_pos_bins
|
||||
self.fast_qkv = fast_qkv
|
||||
self.max_rel_2d_pos = max_rel_2d_pos
|
||||
self.rel_2d_pos_bins = rel_2d_pos_bins
|
||||
self.convert_sync_batchnorm = convert_sync_batchnorm
|
||||
self.image_feature_pool_shape = image_feature_pool_shape
|
||||
self.coordinate_size = coordinate_size
|
||||
self.shape_size = shape_size
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
self.has_spatial_attention_bias = has_spatial_attention_bias
|
||||
self.has_visual_segment_embedding = has_visual_segment_embedding
|
||||
self.detectron2_config_args = (
|
||||
detectron2_config_args if detectron2_config_args is not None else self.get_default_detectron2_config()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_default_detectron2_config(cls):
|
||||
return {
|
||||
"MODEL.MASK_ON": True,
|
||||
"MODEL.PIXEL_STD": [57.375, 57.120, 58.395],
|
||||
"MODEL.BACKBONE.NAME": "build_resnet_fpn_backbone",
|
||||
"MODEL.FPN.IN_FEATURES": ["res2", "res3", "res4", "res5"],
|
||||
"MODEL.ANCHOR_GENERATOR.SIZES": [[32], [64], [128], [256], [512]],
|
||||
"MODEL.RPN.IN_FEATURES": ["p2", "p3", "p4", "p5", "p6"],
|
||||
"MODEL.RPN.PRE_NMS_TOPK_TRAIN": 2000,
|
||||
"MODEL.RPN.PRE_NMS_TOPK_TEST": 1000,
|
||||
"MODEL.RPN.POST_NMS_TOPK_TRAIN": 1000,
|
||||
"MODEL.POST_NMS_TOPK_TEST": 1000,
|
||||
"MODEL.ROI_HEADS.NAME": "StandardROIHeads",
|
||||
"MODEL.ROI_HEADS.NUM_CLASSES": 5,
|
||||
"MODEL.ROI_HEADS.IN_FEATURES": ["p2", "p3", "p4", "p5"],
|
||||
"MODEL.ROI_BOX_HEAD.NAME": "FastRCNNConvFCHead",
|
||||
"MODEL.ROI_BOX_HEAD.NUM_FC": 2,
|
||||
"MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION": 14,
|
||||
"MODEL.ROI_MASK_HEAD.NAME": "MaskRCNNConvUpsampleHead",
|
||||
"MODEL.ROI_MASK_HEAD.NUM_CONV": 4,
|
||||
"MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION": 7,
|
||||
"MODEL.RESNETS.DEPTH": 101,
|
||||
"MODEL.RESNETS.SIZES": [[32], [64], [128], [256], [512]],
|
||||
"MODEL.RESNETS.ASPECT_RATIOS": [[0.5, 1.0, 2.0]],
|
||||
"MODEL.RESNETS.OUT_FEATURES": ["res2", "res3", "res4", "res5"],
|
||||
"MODEL.RESNETS.NUM_GROUPS": 32,
|
||||
"MODEL.RESNETS.WIDTH_PER_GROUP": 8,
|
||||
"MODEL.RESNETS.STRIDE_IN_1X1": False,
|
||||
}
|
||||
|
||||
def get_detectron2_config(self):
|
||||
detectron2_config = detectron2.config.get_cfg()
|
||||
for k, v in self.detectron2_config_args.items():
|
||||
attributes = k.split(".")
|
||||
to_set = detectron2_config
|
||||
for attribute in attributes[:-1]:
|
||||
to_set = getattr(to_set, attribute)
|
||||
setattr(to_set, attributes[-1], v)
|
||||
|
||||
return detectron2_config
|
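Since `get_detectron2_config` applies each dotted key onto a detectron2 `CfgNode` with `getattr`/`setattr`, users can override individual backbone settings by passing their own `detectron2_config_args`. A minimal sketch, assuming detectron2 is installed; the ResNet-18 depth mirrors the smaller backbone the PR's tests aim for, and a full ResNet-18 setup would need a few more overrides than shown here:

# Sketch: overriding part of the default Detectron2 backbone configuration.
from transformers import LayoutLMv2Config

detectron2_args = LayoutLMv2Config.get_default_detectron2_config()
detectron2_args["MODEL.RESNETS.DEPTH"] = 18     # illustrative override only
config = LayoutLMv2Config(detectron2_config_args=detectron2_args)
cfg = config.get_detectron2_config()            # builds the detectron2 CfgNode; needs detectron2
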
src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py (new file, 222 lines)
@@ -0,0 +1,222 @@
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Feature extractor class for LayoutLMv2.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
from ...file_utils import TensorType, is_pytesseract_available, requires_backends
|
||||
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
# soft dependency
|
||||
if is_pytesseract_available():
|
||||
import pytesseract
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ImageInput = Union[
|
||||
Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
|
||||
]
|
||||
|
||||
|
||||
def normalize_box(box, width, height):
|
||||
return [
|
||||
int(1000 * (box[0] / width)),
|
||||
int(1000 * (box[1] / height)),
|
||||
int(1000 * (box[2] / width)),
|
||||
int(1000 * (box[3] / height)),
|
||||
]
|
||||
|
||||
|
||||
def apply_tesseract(image: Image.Image):
|
||||
"""Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
|
||||
|
||||
# apply OCR
|
||||
data = pytesseract.image_to_data(image, output_type="dict")
|
||||
words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
|
||||
|
||||
# filter empty words and corresponding coordinates
|
||||
irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
|
||||
words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
|
||||
left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
|
||||
top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
|
||||
width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
|
||||
height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
|
||||
|
||||
# turn coordinates into (left, top, left+width, top+height) format
|
||||
actual_boxes = []
|
||||
for x, y, w, h in zip(left, top, width, height):
|
||||
actual_box = [x, y, x + w, y + h]
|
||||
actual_boxes.append(actual_box)
|
||||
|
||||
image_width, image_height = image.size
|
||||
|
||||
# finally, normalize the bounding boxes
|
||||
normalized_boxes = []
|
||||
for box in actual_boxes:
|
||||
normalized_boxes.append(normalize_box(box, image_width, image_height))
|
||||
|
||||
assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
|
||||
|
||||
return words, normalized_boxes
|
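`normalize_box` rescales pixel coordinates to the 0-1000 range that LayoutLM-style models expect. A quick worked example with a hypothetical 600x800 page:

# Worked example of the normalization above, for a box given in pixels:
box, width, height = [30, 40, 90, 100], 600, 800
normalized = [
    int(1000 * (box[0] / width)),   # 50
    int(1000 * (box[1] / height)),  # 50
    int(1000 * (box[2] / width)),   # 150
    int(1000 * (box[3] / height)),  # 125
]
print(normalized)  # [50, 50, 150, 125]
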
||||
|
||||
|
||||
class LayoutLMv2FeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
r"""
|
||||
Constructs a LayoutLMv2 feature extractor. This can be used to resize document images to the same size, as well as
|
||||
to apply OCR on them in order to get a list of words and normalized bounding boxes.
|
||||
|
||||
This feature extractor inherits from :class:`~transformers.feature_extraction_utils.PreTrainedFeatureExtractor`
|
||||
which contains most of the main methods. Users should refer to this superclass for more information regarding those
|
||||
methods.
|
||||
|
||||
Args:
|
||||
do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to resize the input to a certain :obj:`size`.
|
||||
size (:obj:`int` or :obj:`Tuple(int)`, `optional`, defaults to 224):
|
||||
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
|
||||
integer is provided, then the input will be resized to (size, size). Only has an effect if :obj:`do_resize`
|
||||
is set to :obj:`True`.
|
||||
resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BILINEAR`):
|
||||
An optional resampling filter. This can be one of :obj:`PIL.Image.NEAREST`, :obj:`PIL.Image.BOX`,
|
||||
:obj:`PIL.Image.BILINEAR`, :obj:`PIL.Image.HAMMING`, :obj:`PIL.Image.BICUBIC` or :obj:`PIL.Image.LANCZOS`.
|
||||
Only has an effect if :obj:`do_resize` is set to :obj:`True`.
|
||||
apply_ocr (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to apply the Tesseract OCR engine to get words + normalized bounding boxes.
|
||||
|
||||
.. note::
|
||||
|
||||
LayoutLMv2FeatureExtractor uses Google's Tesseract OCR engine under the hood.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(self, do_resize=True, size=224, resample=Image.BILINEAR, apply_ocr=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.apply_ocr = apply_ocr
|
||||
if apply_ocr:
|
||||
requires_backends(self, "pytesseract")
|
||||
|
||||
def __call__(
|
||||
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several image(s).
|
||||
|
||||
Args:
|
||||
images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`, defaults to :obj:`'np'`):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects.
|
||||
* :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields:
|
||||
|
||||
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
|
||||
width).
|
||||
- **words** -- Optional words as identified by Tesseract OCR (only when
|
||||
:class:`~transformers.LayoutLMv2FeatureExtractor` was initialized with :obj:`apply_ocr` set to ``True``).
|
||||
- **boxes** -- Optional bounding boxes as identified by Tesseract OCR, normalized based on the image size
|
||||
(only when :class:`~transformers.LayoutLMv2FeatureExtractor` was initialized with :obj:`apply_ocr` set to
|
||||
``True``).
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import LayoutLMv2FeatureExtractor
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> image = Image.open("name_of_your_document - can be a png file, pdf, etc.").convert("RGB")
|
||||
|
||||
>>> # option 1: with apply_ocr=True (default)
|
||||
>>> feature_extractor = LayoutLMv2FeatureExtractor()
|
||||
>>> encoding = feature_extractor(image, return_tensors="pt")
|
||||
>>> print(encoding.keys())
|
||||
>>> # dict_keys(['pixel_values', 'words', 'boxes'])
|
||||
|
||||
>>> # option 2: with apply_ocr=False
|
||||
>>> feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
|
||||
>>> encoding = feature_extractor(image, return_tensors="pt")
|
||||
>>> print(encoding.keys())
|
||||
>>> # dict_keys(['pixel_values'])
|
||||
"""
|
||||
|
||||
# Input type checking for clearer error
|
||||
valid_images = False
|
||||
|
||||
# Check that images has a valid type
|
||||
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
|
||||
valid_images = True
|
||||
elif isinstance(images, (list, tuple)):
|
||||
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
|
||||
valid_images = True
|
||||
|
||||
if not valid_images:
|
||||
raise ValueError(
"Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
|
||||
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples), "
|
||||
f"but is of type {type(images)}."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(images, (list, tuple))
|
||||
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
|
||||
)
|
||||
|
||||
if not is_batched:
|
||||
images = [images]
|
||||
|
||||
# Tesseract OCR to get words + normalized bounding boxes
|
||||
if self.apply_ocr:
|
||||
words_batch = []
|
||||
boxes_batch = []
|
||||
for image in images:
|
||||
words, boxes = apply_tesseract(self.to_pil_image(image))
|
||||
words_batch.append(words)
|
||||
boxes_batch.append(boxes)
|
||||
|
||||
# transformations (resizing)
|
||||
if self.do_resize and self.size is not None:
|
||||
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
|
||||
|
||||
images = [self.to_numpy_array(image, rescale=False) for image in images]
|
||||
# flip color channels from RGB to BGR (as Detectron2 requires this)
|
||||
images = [image[::-1, :, :] for image in images]
|
||||
|
||||
# return as BatchFeature
|
||||
data = {"pixel_values": images}
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
if self.apply_ocr:
|
||||
encoded_inputs["words"] = words_batch
|
||||
encoded_inputs["boxes"] = boxes_batch
|
||||
|
||||
return encoded_inputs
|
src/transformers/models/layoutlmv2/modeling_layoutlmv2.py (new executable file, 1338 lines; diff suppressed because it is too large)

src/transformers/models/layoutlmv2/processing_layoutlmv2.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for LayoutLMv2.
|
||||
"""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...file_utils import TensorType
|
||||
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from .feature_extraction_layoutlmv2 import LayoutLMv2FeatureExtractor
|
||||
from .tokenization_layoutlmv2 import LayoutLMv2Tokenizer
|
||||
from .tokenization_layoutlmv2_fast import LayoutLMv2TokenizerFast
|
||||
|
||||
|
||||
class LayoutLMv2Processor:
|
||||
r"""
|
||||
Constructs a LayoutLMv2 processor which combines a LayoutLMv2 feature extractor and a LayoutLMv2 tokenizer into a
|
||||
single processor.
|
||||
|
||||
:class:`~transformers.LayoutLMv2Processor` offers all the functionalities you need to prepare data for the model.
|
||||
|
||||
It first uses :class:`~transformers.LayoutLMv2FeatureExtractor` to resize document images to a fixed size, and
|
||||
optionally applies OCR to get words and normalized bounding boxes. These are then provided to
|
||||
:class:`~transformers.LayoutLMv2Tokenizer` or :class:`~transformers.LayoutLMv2TokenizerFast`, which turns the words
|
||||
and bounding boxes into token-level :obj:`input_ids`, :obj:`attention_mask`, :obj:`token_type_ids`, :obj:`bbox`.
|
||||
Optionally, one can provide integer :obj:`word_labels`, which are turned into token-level :obj:`labels` for token
|
||||
classification tasks (such as FUNSD, CORD).
|
||||
|
||||
Args:
|
||||
feature_extractor (:obj:`LayoutLMv2FeatureExtractor`):
|
||||
An instance of :class:`~transformers.LayoutLMv2FeatureExtractor`. The feature extractor is a required
|
||||
input.
|
||||
tokenizer (:obj:`LayoutLMv2Tokenizer` or :obj:`LayoutLMv2TokenizerFast`):
|
||||
An instance of :class:`~transformers.LayoutLMv2Tokenizer` or
|
||||
:class:`~transformers.LayoutLMv2TokenizerFast`. The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
if not isinstance(feature_extractor, LayoutLMv2FeatureExtractor):
|
||||
raise ValueError(
f"`feature_extractor` has to be of type {LayoutLMv2FeatureExtractor}, but is {type(feature_extractor)}"
|
||||
)
|
||||
if not isinstance(tokenizer, (LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast)):
|
||||
raise ValueError(
f"`tokenizer` has to be of type {LayoutLMv2Tokenizer} or {LayoutLMv2TokenizerFast}, but is {type(tokenizer)}"
|
||||
)
|
||||
|
||||
self.feature_extractor = feature_extractor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Save a LayoutLMv2 feature_extractor object and LayoutLMv2 tokenizer object to the directory ``save_directory``,
|
||||
so that it can be re-loaded using the :func:`~transformers.LayoutLMv2Processor.from_pretrained` class method.
|
||||
|
||||
.. note::
|
||||
|
||||
This class method is simply calling
|
||||
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and
|
||||
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
|
||||
docstrings of the methods above for more information.
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, use_fast=True, **kwargs):
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.LayoutLMv2Processor` from a pretrained LayoutLMv2 processor.
|
||||
|
||||
.. note::
|
||||
|
||||
This class method is simply calling LayoutLMv2FeatureExtractor's
|
||||
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and
|
||||
LayoutLMv2TokenizerFast's
|
||||
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
|
||||
docstrings of the methods above for more information.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
|
||||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing a feature extractor file saved using the
|
||||
:meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
|
||||
``./my_model_directory/``.
|
||||
- a path or url to a saved feature extractor JSON `file`, e.g.,
|
||||
``./my_model_directory/preprocessor_config.json``.
|
||||
|
||||
use_fast (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to instantiate a fast tokenizer.
|
||||
|
||||
**kwargs
|
||||
Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
|
||||
:class:`~transformers.PreTrainedTokenizer`
|
||||
"""
|
||||
feature_extractor = LayoutLMv2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
if use_fast:
|
||||
tokenizer = LayoutLMv2TokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
else:
|
||||
tokenizer = LayoutLMv2Tokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
|
||||
boxes: Union[List[List[int]], List[List[List[int]]]] = None,
|
||||
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
This method first forwards the :obj:`images` argument to
|
||||
:meth:`~transformers.LayoutLMv2FeatureExtractor.__call__`. In case :class:`~LayoutLMv2FeatureExtractor` was
|
||||
initialized with :obj:`apply_ocr` set to ``True``, it passes the obtained words and bounding boxes along with
|
||||
the additional arguments to :meth:`~transformers.LayoutLMv2Tokenizer.__call__` and returns the output, together
|
||||
with resized :obj:`images`. In case :class:`~LayoutLMv2FeatureExtractor` was initialized with :obj:`apply_ocr`
|
||||
set to ``False``, it passes the words (:obj:`text`/:obj:`text_pair`) and :obj:`boxes` specified by the user
|
||||
along with the additional arguments to :meth:`~transformers.LayoutLMv2Tokenizer.__call__` and returns the
|
||||
output, together with resized :obj:`images`.
|
||||
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# verify input
|
||||
if self.feature_extractor.apply_ocr and (boxes is not None):
|
||||
raise ValueError(
|
||||
"You cannot provide bounding boxes "
|
||||
"if you initialized the feature extractor with apply_ocr set to True."
|
||||
)
|
||||
|
||||
if self.feature_extractor.apply_ocr and (word_labels is not None):
|
||||
raise ValueError(
|
||||
"You cannot provide word labels "
|
||||
"if you initialized the feature extractor with apply_ocr set to True."
|
||||
)
|
||||
|
||||
# first, apply the feature extractor
|
||||
features = self.feature_extractor(images=images, return_tensors=return_tensors)
|
||||
|
||||
# second, apply the tokenizer
|
||||
if text is not None and self.feature_extractor.apply_ocr and text_pair is None:
|
||||
if isinstance(text, str):
|
||||
text = [text] # add batch dimension (as the feature extractor always adds a batch dimension)
|
||||
text_pair = features["words"]
|
||||
|
||||
encoded_inputs = self.tokenizer(
|
||||
text=text if text is not None else features["words"],
|
||||
text_pair=text_pair if text_pair is not None else None,
|
||||
boxes=boxes if boxes is not None else features["boxes"],
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# add pixel values
|
||||
encoded_inputs["image"] = features.pop("pixel_values")
|
||||
|
||||
return encoded_inputs
|
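Taken together, the processor gives a single call from a document image to model-ready tensors. A minimal sketch, assuming pytesseract plus the Tesseract binary are installed, torch is available for the "pt" tensors, and "document.png" is a hypothetical local scan:

# Sketch: end-to-end preprocessing with the combined processor.
from PIL import Image
from transformers import LayoutLMv2Processor

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased")
image = Image.open("document.png").convert("RGB")  # hypothetical local scan

encoding = processor(image, return_tensors="pt")   # OCR runs inside the feature extractor
print(encoding.keys())                             # expect input_ids, token_type_ids, attention_mask, bbox, image
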
src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py (new file, 1478 lines; diff suppressed because it is too large)

src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py (new file, 806 lines)
@@ -0,0 +1,806 @@
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fast tokenization class for LayoutLMv2. It overwrites 2 methods of the slow tokenizer class, namely _batch_encode_plus
|
||||
and _encode_plus, in which the Rust tokenizer is used.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from tokenizers import normalizers
|
||||
|
||||
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
||||
from ...tokenization_utils_base import (
|
||||
ENCODE_KWARGS_DOCSTRING,
|
||||
BatchEncoding,
|
||||
EncodedInput,
|
||||
PreTokenizedInput,
|
||||
TextInput,
|
||||
TextInputPair,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
from .tokenization_layoutlmv2 import LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, LayoutLMv2Tokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/vocab.txt",
|
||||
},
|
||||
"tokenizer_file": {
|
||||
"microsoft/layoutlmv2-base-uncased": "https://huggingface.co/microsoft/layoutlmv2-base-uncased/resolve/main/tokenizer.json",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"microsoft/layoutlmv2-base-uncased": 512,
|
||||
}
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"microsoft/layoutlmv2-base-uncased": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
class LayoutLMv2TokenizerFast(PreTrainedTokenizerFast):
|
||||
r"""
|
||||
Construct a "fast" LayoutLMv2 tokenizer (backed by HuggingFace's `tokenizers` library). Based on WordPiece.
|
||||
|
||||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
|
||||
methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`str`):
|
||||
File containing the vocabulary.
|
||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to lowercase the input when tokenizing.
|
||||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
cls_token_box (:obj:`List[int]`, `optional`, defaults to :obj:`[0, 0, 0, 0]`):
|
||||
The bounding box to use for the special [CLS] token.
|
||||
sep_token_box (:obj:`List[int]`, `optional`, defaults to :obj:`[1000, 1000, 1000, 1000]`):
|
||||
The bounding box to use for the special [SEP] token.
|
||||
pad_token_box (:obj:`List[int]`, `optional`, defaults to :obj:`[0, 0, 0, 0]`):
|
||||
The bounding box to use for the special [PAD] token.
|
||||
pad_token_label (:obj:`int`, `optional`, defaults to -100):
|
||||
The label to use for padding tokens. Defaults to -100, which is the :obj:`ignore_index` of PyTorch's
|
||||
CrossEntropyLoss.
|
||||
only_label_first_subword (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to only label the first subword, in case word labels are provided.
|
||||
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see `this
|
||||
issue <https://github.com/huggingface/transformers/issues/328>`__).
|
||||
strip_accents: (:obj:`bool`, `optional`):
|
||||
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
||||
value for :obj:`lowercase` (as in the original LayoutLMv2).
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
slow_tokenizer_class = LayoutLMv2Tokenizer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file=None,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]",
|
||||
cls_token_box=[0, 0, 0, 0],
|
||||
sep_token_box=[1000, 1000, 1000, 1000],
|
||||
pad_token_box=[0, 0, 0, 0],
|
||||
pad_token_label=-100,
|
||||
only_label_first_subword=True,
|
||||
tokenize_chinese_chars=True,
|
||||
strip_accents=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
do_lower_case=do_lower_case,
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
mask_token=mask_token,
|
||||
cls_token_box=cls_token_box,
|
||||
sep_token_box=sep_token_box,
|
||||
pad_token_box=pad_token_box,
|
||||
pad_token_label=pad_token_label,
|
||||
only_label_first_subword=only_label_first_subword,
|
||||
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||
strip_accents=strip_accents,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||
if (
|
||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||
):
|
||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||
pre_tok_state["lowercase"] = do_lower_case
|
||||
pre_tok_state["strip_accents"] = strip_accents
|
||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
# additional properties
|
||||
self.cls_token_box = cls_token_box
|
||||
self.sep_token_box = sep_token_box
|
||||
self.pad_token_box = pad_token_box
|
||||
self.pad_token_label = pad_token_label
|
||||
self.only_label_first_subword = only_label_first_subword
|
||||
|
||||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
|
||||
text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
|
||||
boxes: Union[List[List[int]], List[List[List[int]]]] = None,
|
||||
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
|
||||
sequences with word-level normalized bounding boxes and optional labels.
|
||||
|
||||
Args:
|
||||
text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings
|
||||
(words of a single example or questions of a batch of examples) or a list of list of strings (batch of
|
||||
words).
|
||||
text_pair (:obj:`List[str]`, :obj:`List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence should be a list of strings
|
||||
(pretokenized string).
|
||||
boxes (:obj:`List[List[int]]`, :obj:`List[List[List[int]]]`):
|
||||
Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale.
|
||||
word_labels (:obj:`List[int]`, :obj:`List[List[int]]`, `optional`):
|
||||
Word-level integer labels (for token classification tasks such as FUNSD, CORD).
|
||||
"""
|
||||
# Input type checking for clearer error
|
||||
def _is_valid_text_input(t):
|
||||
if isinstance(t, str):
|
||||
# Strings are fine
|
||||
return True
|
||||
elif isinstance(t, (list, tuple)):
|
||||
# List are fine as long as they are...
|
||||
if len(t) == 0:
|
||||
# ... empty
|
||||
return True
|
||||
elif isinstance(t[0], str):
|
||||
# ... list of strings
|
||||
return True
|
||||
elif isinstance(t[0], (list, tuple)):
|
||||
# ... list with an empty list or with a list of strings
|
||||
return len(t[0]) == 0 or isinstance(t[0][0], str)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
if text_pair is not None:
|
||||
# in case text + text_pair are provided, text = questions, text_pair = words
|
||||
if not _is_valid_text_input(text):
|
||||
raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch of examples).")
|
||||
if not isinstance(text_pair, (list, tuple)):
|
||||
raise ValueError(
"words must be of type `List[str]` (single pretokenized example), "
|
||||
"or `List[List[str]]` (batch of pretokenized examples)."
|
||||
)
|
||||
else:
|
||||
# in case only text is provided => must be words
|
||||
if not isinstance(text, (list, tuple)):
|
||||
raise ValueError(
"Words must be of type `List[str]` (single pretokenized example), "
|
||||
"or `List[List[str]]` (batch of pretokenized examples)."
|
||||
)
|
||||
|
||||
if text_pair is not None:
|
||||
is_batched = isinstance(text, (list, tuple))
|
||||
else:
|
||||
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
|
||||
|
||||
words = text if text_pair is None else text_pair
|
||||
assert boxes is not None, "You must provide corresponding bounding boxes"
|
||||
if is_batched:
|
||||
assert len(words) == len(boxes), "You must provide words and boxes for an equal amount of examples"
|
||||
for words_example, boxes_example in zip(words, boxes):
|
||||
assert len(words_example) == len(
|
||||
boxes_example
|
||||
), "You must provide as many words as there are bounding boxes"
|
||||
else:
|
||||
assert len(words) == len(boxes), "You must provide as many words as there are bounding boxes"
|
||||
|
||||
if is_batched:
|
||||
if text_pair is not None and len(text) != len(text_pair):
|
||||
raise ValueError(
|
||||
f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}."
|
||||
)
|
||||
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
|
||||
is_pair = bool(text_pair is not None)
|
||||
return self.batch_encode_plus(
|
||||
batch_text_or_text_pairs=batch_text_or_text_pairs,
|
||||
is_pair=is_pair,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.encode_plus(
|
||||
text=text,
|
||||
text_pair=text_pair,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
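Before the remaining encoding methods, here is a minimal sketch of the `__call__` signature defined above, used directly with pre-extracted words and boxes (the words, boxes and labels are hypothetical, and the checkpoint's tokenizer files are assumed to be reachable):

# Sketch: encoding words + normalized boxes (+ word labels) with the fast tokenizer.
from transformers import LayoutLMv2TokenizerFast

tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
words = ["hello", "world"]                             # hypothetical OCR output
boxes = [[637, 773, 693, 782], [698, 773, 733, 782]]   # already on the 0-1000 scale
encoding = tokenizer(words, boxes=boxes, word_labels=[0, 1])
print(encoding["input_ids"], encoding["bbox"], encoding["labels"])
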
||||
|
||||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[
|
||||
List[TextInput],
|
||||
List[TextInputPair],
|
||||
List[PreTokenizedInput],
|
||||
],
|
||||
is_pair: bool = None,
|
||||
boxes: Optional[List[List[List[int]]]] = None,
|
||||
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
""" """
|
||||
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._batch_encode_plus(
|
||||
batch_text_or_text_pairs=batch_text_or_text_pairs,
|
||||
is_pair=is_pair,
|
||||
boxes=boxes,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
|
||||
batched_input = [(text, pair)] if pair else [text]
|
||||
encodings = self._tokenizer.encode_batch(
|
||||
batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs
|
||||
)
|
||||
|
||||
return encodings[0].tokens
|
||||
|
||||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, LAYOUTLMV2_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def encode_plus(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput],
|
||||
text_pair: Optional[PreTokenizedInput] = None,
|
||||
boxes: Optional[List[List[int]]] = None,
|
||||
word_labels: Optional[List[int]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Tokenize and prepare for the model a sequence or a pair of sequences.

.. warning::
    This method is deprecated, ``__call__`` should be used instead.
|
||||
|
||||
Args:
|
||||
text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
|
||||
The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings.
|
||||
text_pair (:obj:`List[str]` or :obj:`List[int]`, `optional`):
|
||||
Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a
|
||||
list of list of strings (words of a batch of examples).
|
||||
"""
|
||||
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._encode_plus(
|
||||
text=text,
|
||||
boxes=boxes,
|
||||
text_pair=text_pair,
|
||||
word_labels=word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[
|
||||
List[TextInput],
|
||||
List[TextInputPair],
|
||||
List[PreTokenizedInput],
|
||||
],
|
||||
is_pair: bool = None,
|
||||
boxes: Optional[List[List[List[int]]]] = None,
|
||||
word_labels: Optional[List[List[int]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
) -> BatchEncoding:
|
||||
|
||||
if not isinstance(batch_text_or_text_pairs, list):
|
||||
raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
|
||||
|
||||
# Set the truncation and padding strategy and restore the initial configuration
|
||||
self.set_truncation_and_padding(
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
)
|
||||
|
||||
if is_pair:
|
||||
batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
|
||||
|
||||
encodings = self._tokenizer.encode_batch(
|
||||
batch_text_or_text_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
is_pretokenized=True, # we set this to True as LayoutLMv2 always expects pretokenized inputs
|
||||
)
|
||||
|
||||
# Convert encoding to dict
|
||||
# `Tokens` has type: Tuple[
|
||||
# List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
|
||||
# List[EncodingFast]
|
||||
# ]
|
||||
# with nested dimensions corresponding to batch, overflows, sequence length
|
||||
tokens_and_encodings = [
|
||||
self._convert_encoding(
|
||||
encoding=encoding,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=True
|
||||
if word_labels is not None
|
||||
else return_offsets_mapping, # we use offsets to create the labels
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
)
|
||||
for encoding in encodings
|
||||
]
|
||||
|
||||
# Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
|
||||
# From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
|
||||
# (we say ~ because the number of overflows varies from example to example within the batch)
|
||||
#
|
||||
# To match each overflowing sample with the original sample in the batch
|
||||
# we add an overflow_to_sample_mapping array (see below)
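# Illustrative example (not from the diff): if the first sample yields 1 chunk and the second yields 3
# overflowing chunks, overflow_to_sample_mapping will be [0, 1, 1, 1]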
|
||||
sanitized_tokens = {}
|
||||
for key in tokens_and_encodings[0][0].keys():
|
||||
stack = [e for item, _ in tokens_and_encodings for e in item[key]]
|
||||
sanitized_tokens[key] = stack
|
||||
sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
|
||||
|
||||
# If returning overflowing tokens, we need to return a mapping
|
||||
# from the batch idx to the original sample
|
||||
if return_overflowing_tokens:
|
||||
overflow_to_sample_mapping = []
|
||||
for i, (toks, _) in enumerate(tokens_and_encodings):
|
||||
overflow_to_sample_mapping += [i] * len(toks["input_ids"])
|
||||
sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
|
||||
|
||||
for input_ids in sanitized_tokens["input_ids"]:
|
||||
self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
|
||||
|
||||
# create the token boxes
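# e.g. (illustrative) a word split into ["lay", "##out"] gets its single word-level box repeated for both
# tokens, while [CLS]/[SEP]/[PAD] tokens receive cls_token_box/sep_token_box/pad_token_box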
|
||||
token_boxes = []
|
||||
for batch_index in range(len(sanitized_tokens["input_ids"])):
|
||||
if return_overflowing_tokens:
|
||||
original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
|
||||
else:
|
||||
original_index = batch_index
|
||||
token_boxes_example = []
|
||||
for id, sequence_id, word_id in zip(
|
||||
sanitized_tokens["input_ids"][batch_index],
|
||||
sanitized_encodings[batch_index].sequence_ids,
|
||||
sanitized_encodings[batch_index].word_ids,
|
||||
):
|
||||
if word_id is not None:
|
||||
if is_pair and sequence_id == 0:
|
||||
token_boxes_example.append(self.pad_token_box)
|
||||
else:
|
||||
token_boxes_example.append(boxes[original_index][word_id])
|
||||
else:
|
||||
if id == self.cls_token_id:
|
||||
token_boxes_example.append(self.cls_token_box)
|
||||
elif id == self.sep_token_id:
|
||||
token_boxes_example.append(self.sep_token_box)
|
||||
elif id == self.pad_token_id:
|
||||
token_boxes_example.append(self.pad_token_box)
|
||||
else:
|
||||
raise ValueError("Id not recognized")
|
||||
token_boxes.append(token_boxes_example)
|
||||
|
||||
sanitized_tokens["bbox"] = token_boxes
|
||||
|
||||
# optionally, create the labels
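# Illustrative example: with word_labels [3, 7] and the second word split into two subtokens,
# only_label_first_subword=True yields token labels [3, 7, pad_token_label] for the word tokens;
# special tokens always receive pad_token_label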
|
||||
if word_labels is not None:
|
||||
labels = []
|
||||
for batch_index in range(len(sanitized_tokens["input_ids"])):
|
||||
if return_overflowing_tokens:
|
||||
original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index]
|
||||
else:
|
||||
original_index = batch_index
|
||||
labels_example = []
|
||||
for id, offset, word_id in zip(
|
||||
sanitized_tokens["input_ids"][batch_index],
|
||||
sanitized_tokens["offset_mapping"][batch_index],
|
||||
sanitized_encodings[batch_index].word_ids,
|
||||
):
|
||||
if word_id is not None:
|
||||
if self.only_label_first_subword:
|
||||
if offset[0] == 0:
|
||||
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
|
||||
labels_example.append(word_labels[original_index][word_id])
|
||||
else:
|
||||
labels_example.append(self.pad_token_label)
|
||||
else:
|
||||
labels_example.append(word_labels[original_index][word_id])
|
||||
else:
|
||||
labels_example.append(self.pad_token_label)
|
||||
labels.append(labels_example)
|
||||
|
||||
sanitized_tokens["labels"] = labels
|
||||
# finally, remove offsets if the user didn't want them
|
||||
if not return_offsets_mapping:
|
||||
del sanitized_tokens["offset_mapping"]
|
||||
|
||||
return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
|
||||
|
||||
def _encode_plus(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput],
|
||||
text_pair: Optional[PreTokenizedInput] = None,
|
||||
boxes: Optional[List[List[int]]] = None,
|
||||
word_labels: Optional[List[int]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[bool] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
|
||||
# make it a batched input
|
||||
# 2 options:
|
||||
# 1) only text, in which case text must be a list of str
# 2) text + text_pair, in which case text is a str and text_pair a list of str
|
||||
batched_input = [(text, text_pair)] if text_pair else [text]
|
||||
batched_boxes = [boxes]
|
||||
batched_word_labels = [word_labels] if word_labels is not None else None
|
||||
batched_output = self._batch_encode_plus(
|
||||
batched_input,
|
||||
is_pair=bool(text_pair is not None),
|
||||
boxes=batched_boxes,
|
||||
word_labels=batched_word_labels,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# If return_tensors is None, we can remove the leading batch axis
|
||||
# Overflowing tokens are returned as a batch of output so we keep them in this case
|
||||
if return_tensors is None and not return_overflowing_tokens:
|
||||
batched_output = BatchEncoding(
|
||||
{
|
||||
key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
|
||||
for key, value in batched_output.items()
|
||||
},
|
||||
batched_output.encodings,
|
||||
)
|
||||
|
||||
self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
|
||||
|
||||
return batched_output
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
>= 7.5 (Volta).
|
||||
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
if self.padding_side == "right":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = (
|
||||
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
|
||||
)
|
||||
if "bbox" in encoded_inputs:
|
||||
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference
|
||||
if "labels" in encoded_inputs:
|
||||
encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
|
||||
elif self.padding_side == "left":
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
]
|
||||
if "bbox" in encoded_inputs:
|
||||
encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"]
|
||||
if "labels" in encoded_inputs:
|
||||
encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["bbox"]
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
else:
|
||||
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
||||
elif return_attention_mask and "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * len(required_input)
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A BERT sequence has the following format:
|
||||
|
||||
- single sequence: ``[CLS] X [SEP]``
|
||||
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||
|
||||
if token_ids_1:
|
||||
output += token_ids_1 + [self.sep_token_id]
|
||||
|
||||
return output
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:

::

    0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
    | first sequence    | second sequence |

If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
||||
sequence(s).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||
return tuple(files)
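# Illustrative usage sketch (not part of the diff; the checkpoint name and values below are only examples):
#
#   from transformers import LayoutLMv2TokenizerFast
#
#   tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
#   words = ["hello", "world"]
#   boxes = [[48, 84, 73, 128], [92, 84, 160, 128]]  # one 0-1000 normalized box per word
#   encoding = tokenizer(words, boxes=boxes, word_labels=[0, 1], return_tensors="pt")
#   # encoding now holds input_ids, token_type_ids, attention_mask, bbox and labels,
#   # aligned token-by-token as implemented in the methods above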
|
@@ -32,11 +32,13 @@ from transformers import logging as transformers_logging
|
||||
from .deepspeed import is_deepspeed_available
|
||||
from .file_utils import (
|
||||
is_datasets_available,
|
||||
is_detectron2_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_keras2onnx_available,
|
||||
is_onnx_available,
|
||||
is_pandas_available,
|
||||
is_pytesseract_available,
|
||||
is_rjieba_available,
|
||||
is_scatter_available,
|
||||
is_sentencepiece_available,
|
||||
@@ -348,6 +350,16 @@ def require_pandas(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_pytesseract(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
|
||||
"""
|
||||
if not is_pytesseract_available():
|
||||
return unittest.skip("test requires PyTesseract")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_scatter(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
|
||||
@@ -457,6 +469,14 @@ def require_datasets(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_detectron2(test_case):
|
||||
"""Decorator marking a test that requires detectron2."""
|
||||
if not is_detectron2_available():
|
||||
return unittest.skip("test requires `detectron2`")(test_case)
|
||||
else:
|
||||
return test_case
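# Illustrative usage: these decorators are stacked on test classes or methods, e.g. the LayoutLMv2 model
# tests later in this diff use
#
#   @require_torch
#   @require_detectron2
#   class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
#       ...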
|
||||
|
||||
|
||||
def require_faiss(test_case):
|
||||
"""Decorator marking a test that requires faiss."""
|
||||
if not is_faiss_available():
|
||||
|
src/transformers/utils/dummy_detectron2_objects.py (new file, +14 lines)
@@ -0,0 +1,14 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..file_utils import requires_backends
|
||||
|
||||
|
||||
LAYOUTLM_V2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class LayoutLMv2Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["detectron2"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["detectron2"])
|
@@ -2004,6 +2004,54 @@ class LayoutLMPreTrainedModel:
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class LayoutLMv2ForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv2ForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv2ForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv2Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayoutLMv2PreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@@ -164,6 +164,15 @@ class LayoutLMTokenizerFast:
|
||||
requires_backends(cls, ["tokenizers"])
|
||||
|
||||
|
||||
class LayoutLMv2TokenizerFast:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["tokenizers"])
|
||||
|
||||
|
||||
class LEDTokenizerFast:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
@@ -36,6 +36,20 @@ class DetrFeatureExtractor:
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LayoutLMv2FeatureExtractor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LayoutLMv2Processor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["vision"])
|
||||
|
||||
|
||||
class ViTFeatureExtractor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
tests/test_feature_extraction_layoutlmv2.py (new file, +221 lines)
@@ -0,0 +1,221 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.file_utils import is_pytesseract_available, is_torch_available
|
||||
from transformers.testing_utils import require_pytesseract, require_torch
|
||||
|
||||
from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_pytesseract_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import LayoutLMv2FeatureExtractor
|
||||
|
||||
|
||||
class LayoutLMv2FeatureExtractionTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=18,
|
||||
apply_ocr=True,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.apply_ocr = apply_ocr
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {"do_resize": self.do_resize, "size": self.size, "apply_ocr": self.apply_ocr}
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_pytesseract
|
||||
class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = LayoutLMv2FeatureExtractor if is_pytesseract_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = LayoutLMv2FeatureExtractionTester(self)
|
||||
|
||||
@property
|
||||
def feat_extract_dict(self):
|
||||
return self.feature_extract_tester.prepare_feat_extract_dict()
|
||||
|
||||
def test_feat_extract_properties(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
self.assertTrue(hasattr(feature_extractor, "do_resize"))
|
||||
self.assertTrue(hasattr(feature_extractor, "size"))
|
||||
self.assertTrue(hasattr(feature_extractor, "apply_ocr"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PIL images
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoding = feature_extractor(image_inputs[0], return_tensors="pt")
|
||||
self.assertEqual(
|
||||
encoding.pixel_values.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
self.assertIsInstance(encoding.words, list)
|
||||
self.assertIsInstance(encoding.boxes, list)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
1,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.feature_extract_tester.batch_size,
|
||||
self.feature_extract_tester.num_channels,
|
||||
self.feature_extract_tester.size,
|
||||
self.feature_extract_tester.size,
|
||||
),
|
||||
)
|
||||
|
||||
def test_layoutlmv2_integration_test(self):
|
||||
# with apply_ocr = True
|
||||
feature_extractor = LayoutLMv2FeatureExtractor()
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")
|
||||
|
||||
image = Image.open(ds[0]["file"]).convert("RGB")
|
||||
|
||||
encoding = feature_extractor(image, return_tensors="pt")
|
||||
|
||||
self.assertEqual(encoding.pixel_values.shape, (1, 3, 224, 224))
|
||||
self.assertEqual(len(encoding.words), len(encoding.boxes))
|
||||
|
||||
# fmt: off
|
||||
# the words and boxes were obtained with Tesseract 4.1.1
|
||||
expected_words = [['11:14', 'to', '11:39', 'a.m', '11:39', 'to', '11:44', 'a.m.', '11:44', 'a.m.', 'to', '12:25', 'p.m.', '12:25', 'to', '12:58', 'p.m.', '12:58', 'to', '4:00', 'p.m.', '2:00', 'to', '5:00', 'p.m.', 'Coffee', 'Break', 'Coffee', 'will', 'be', 'served', 'for', 'men', 'and', 'women', 'in', 'the', 'lobby', 'adjacent', 'to', 'exhibit', 'area.', 'Please', 'move', 'into', 'exhibit', 'area.', '(Exhibits', 'Open)', 'TRRF', 'GENERAL', 'SESSION', '(PART', '|)', 'Presiding:', 'Lee', 'A.', 'Waller', 'TRRF', 'Vice', 'President', '“Introductory', 'Remarks”', 'Lee', 'A.', 'Waller,', 'TRRF', 'Vice', 'Presi-', 'dent', 'Individual', 'Interviews', 'with', 'TRRF', 'Public', 'Board', 'Members', 'and', 'Sci-', 'entific', 'Advisory', 'Council', 'Mem-', 'bers', 'Conducted', 'by', 'TRRF', 'Treasurer', 'Philip', 'G.', 'Kuehn', 'to', 'get', 'answers', 'which', 'the', 'public', 'refrigerated', 'warehousing', 'industry', 'is', 'looking', 'for.', 'Plus', 'questions', 'from', 'the', 'floor.', 'Dr.', 'Emil', 'M.', 'Mrak,', 'University', 'of', 'Cal-', 'ifornia,', 'Chairman,', 'TRRF', 'Board;', 'Sam', 'R.', 'Cecil,', 'University', 'of', 'Georgia', 'College', 'of', 'Agriculture;', 'Dr.', 'Stanley', 'Charm,', 'Tufts', 'University', 'School', 'of', 'Medicine;', 'Dr.', 'Robert', 'H.', 'Cotton,', 'ITT', 'Continental', 'Baking', 'Company;', 'Dr.', 'Owen', 'Fennema,', 'University', 'of', 'Wis-', 'consin;', 'Dr.', 'Robert', 'E.', 'Hardenburg,', 'USDA.', 'Questions', 'and', 'Answers', 'Exhibits', 'Open', 'Capt.', 'Jack', 'Stoney', 'Room', 'TRRF', 'Scientific', 'Advisory', 'Council', 'Meeting', 'Ballroom', 'Foyer']] # noqa: E231
|
||||
expected_boxes = [[[141, 57, 214, 69], [228, 58, 252, 69], [141, 75, 216, 88], [230, 79, 280, 88], [142, 260, 218, 273], [230, 261, 255, 273], [143, 279, 218, 290], [231, 282, 290, 291], [143, 342, 218, 354], [231, 345, 289, 355], [202, 362, 227, 373], [143, 379, 220, 392], [231, 382, 291, 394], [144, 714, 220, 726], [231, 715, 256, 726], [144, 732, 220, 745], [232, 736, 291, 747], [144, 769, 218, 782], [231, 770, 256, 782], [141, 788, 202, 801], [215, 791, 274, 804], [143, 826, 204, 838], [215, 826, 240, 838], [142, 844, 202, 857], [215, 847, 274, 859], [334, 57, 427, 69], [440, 57, 522, 69], [369, 75, 461, 88], [469, 75, 516, 88], [528, 76, 562, 88], [570, 76, 667, 88], [675, 75, 711, 87], [721, 79, 778, 88], [789, 75, 840, 88], [369, 97, 470, 107], [484, 94, 507, 106], [518, 94, 562, 107], [576, 94, 655, 110], [668, 94, 792, 109], [804, 95, 829, 107], [369, 113, 465, 125], [477, 116, 547, 125], [562, 113, 658, 125], [671, 116, 748, 125], [761, 113, 811, 125], [369, 131, 465, 143], [477, 133, 548, 143], [563, 130, 698, 145], [710, 130, 802, 146], [336, 171, 412, 183], [423, 171, 572, 183], [582, 170, 716, 184], [728, 171, 817, 187], [829, 171, 844, 186], [338, 197, 482, 212], [507, 196, 557, 209], [569, 196, 595, 208], [610, 196, 702, 209], [505, 214, 583, 226], [595, 214, 656, 227], [670, 215, 807, 227], [335, 259, 543, 274], [556, 259, 708, 272], [372, 279, 422, 291], [435, 279, 460, 291], [474, 279, 574, 292], [587, 278, 664, 291], [676, 278, 738, 291], [751, 279, 834, 291], [372, 298, 434, 310], [335, 341, 483, 354], [497, 341, 655, 354], [667, 341, 728, 354], [740, 341, 825, 354], [335, 360, 430, 372], [442, 360, 534, 372], [545, 359, 687, 372], [697, 360, 754, 372], [765, 360, 823, 373], [334, 378, 428, 391], [440, 378, 577, 394], [590, 378, 705, 391], [720, 378, 801, 391], [334, 397, 400, 409], [370, 416, 529, 429], [544, 416, 576, 432], [587, 416, 665, 428], [677, 416, 814, 429], [372, 435, 452, 450], [465, 434, 495, 447], [511, 434, 600, 447], [611, 436, 637, 447], [649, 436, 694, 451], [705, 438, 824, 447], [369, 453, 452, 466], [464, 454, 509, 466], [522, 453, 611, 469], [625, 453, 792, 469], [370, 472, 556, 488], [570, 472, 684, 487], [697, 472, 718, 485], [732, 472, 835, 488], [369, 490, 411, 503], [425, 490, 484, 503], [496, 490, 635, 506], [645, 490, 707, 503], [718, 491, 761, 503], [771, 490, 840, 503], [336, 510, 374, 521], [388, 510, 447, 522], [460, 510, 489, 521], [503, 510, 580, 522], [592, 509, 736, 525], [745, 509, 770, 522], [781, 509, 840, 522], [338, 528, 434, 541], [448, 528, 596, 541], [609, 527, 687, 540], [700, 528, 792, 541], [336, 546, 397, 559], [407, 546, 431, 559], [443, 546, 525, 560], [537, 546, 680, 562], [688, 546, 714, 559], [722, 546, 837, 562], [336, 565, 449, 581], [461, 565, 485, 577], [497, 565, 665, 581], [681, 565, 718, 577], [732, 565, 837, 580], [337, 584, 438, 597], [452, 583, 521, 596], [535, 584, 677, 599], [690, 583, 787, 596], [801, 583, 825, 596], [338, 602, 478, 615], [492, 602, 530, 614], [543, 602, 638, 615], [650, 602, 676, 614], [688, 602, 788, 615], [802, 602, 843, 614], [337, 621, 502, 633], [516, 621, 615, 637], [629, 621, 774, 636], [789, 621, 827, 633], [337, 639, 418, 652], [432, 640, 571, 653], [587, 639, 731, 655], [743, 639, 769, 652], [780, 639, 841, 652], [338, 658, 440, 673], [455, 658, 491, 670], [508, 658, 602, 671], [616, 658, 638, 670], [654, 658, 835, 674], [337, 677, 429, 689], [337, 714, 482, 726], [495, 714, 548, 726], [561, 714, 683, 726], [338, 770, 461, 782], [474, 769, 554, 785], [489, 788, 562, 803], 
[576, 788, 643, 801], [656, 787, 751, 804], [764, 788, 844, 801], [334, 825, 421, 838], [430, 824, 574, 838], [584, 824, 723, 841], [335, 844, 450, 857], [464, 843, 583, 860], [628, 862, 755, 875], [769, 861, 848, 878]]] # noqa: E231
|
||||
# fmt: on
|
||||
|
||||
self.assertListEqual(encoding.words, expected_words)
|
||||
self.assertListEqual(encoding.boxes, expected_boxes)
|
||||
|
||||
# with apply_ocr = False
|
||||
feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
|
||||
|
||||
encoding = feature_extractor(image, return_tensors="pt")
|
||||
|
||||
self.assertEqual(
|
||||
encoding.pixel_values.shape,
|
||||
(
|
||||
1,
|
||||
3,
|
||||
224,
|
||||
224,
|
||||
),
|
||||
)
|
tests/test_modeling_layoutlmv2.py (new file, +532 lines)
@@ -0,0 +1,532 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch LayoutLMv2 model. """
|
||||
|
||||
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.file_utils import is_detectron2_available, is_torch_available
|
||||
from transformers.testing_utils import require_detectron2, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MODEL_MAPPING,
|
||||
LayoutLMv2Config,
|
||||
LayoutLMv2ForQuestionAnswering,
|
||||
LayoutLMv2ForSequenceClassification,
|
||||
LayoutLMv2ForTokenClassification,
|
||||
LayoutLMv2Model,
|
||||
)
|
||||
from transformers.models.layoutlmv2.modeling_layoutlmv2 import LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
if is_detectron2_available():
|
||||
from detectron2.structures.image_list import ImageList
|
||||
|
||||
|
||||
class LayoutLMv2ModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
num_channels=3,
|
||||
image_size=4,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=36,
|
||||
num_hidden_layers=3,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
image_feature_pool_shape=[7, 7, 256],
|
||||
coordinate_size=6,
|
||||
shape_size=6,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
range_bbox=1000,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.image_feature_pool_shape = image_feature_pool_shape
|
||||
self.coordinate_size = coordinate_size
|
||||
self.shape_size = shape_size
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.range_bbox = range_bbox
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
bbox = ids_tensor([self.batch_size, self.seq_length, 4], self.range_bbox)
|
||||
# Ensure that bbox is legal
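# i.e. swap coordinates where necessary so that every box [x0, y0, x1, y1] satisfies x0 <= x1 and y0 <= y1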
|
||||
for i in range(bbox.shape[0]):
|
||||
for j in range(bbox.shape[1]):
|
||||
if bbox[i, j, 3] < bbox[i, j, 1]:
|
||||
t = bbox[i, j, 3]
|
||||
bbox[i, j, 3] = bbox[i, j, 1]
|
||||
bbox[i, j, 1] = t
|
||||
if bbox[i, j, 2] < bbox[i, j, 0]:
|
||||
t = bbox[i, j, 2]
|
||||
bbox[i, j, 2] = bbox[i, j, 0]
|
||||
bbox[i, j, 0] = t
|
||||
|
||||
image = ImageList(
|
||||
torch.zeros(self.batch_size, self.num_channels, self.image_size, self.image_size), self.image_size
|
||||
)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
|
||||
config = LayoutLMv2Config(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
image_feature_pool_shape=self.image_feature_pool_shape,
|
||||
coordinate_size=self.coordinate_size,
|
||||
shape_size=self.shape_size,
|
||||
)
|
||||
|
||||
# use smaller resnet backbone to make tests faster
|
||||
config.detectron2_config_args["MODEL.RESNETS.DEPTH"] = 18
|
||||
config.detectron2_config_args["MODEL.RESNETS.RES2_OUT_CHANNELS"] = 64
|
||||
config.detectron2_config_args["MODEL.RESNETS.NUM_GROUPS"] = 1
|
||||
|
||||
return config, input_ids, bbox, image, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, bbox, image, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
model = LayoutLMv2Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, bbox=bbox, image=image, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, bbox=bbox, image=image, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, bbox=bbox, image=image)
|
||||
|
||||
# LayoutLMv2 has a different expected sequence length, since visual tokens are appended to the text tokens
|
||||
expected_seq_len = self.seq_length + self.image_feature_pool_shape[0] * self.image_feature_pool_shape[1]
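# e.g. with seq_length=7 and the 7x7 image feature pool configured in this tester: 7 + 7 * 7 = 56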
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, bbox, image, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LayoutLMv2ForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
bbox=bbox,
|
||||
image=image,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, bbox, image, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = LayoutLMv2ForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
bbox=bbox,
|
||||
image=image,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=token_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, bbox, image, token_type_ids, input_mask, sequence_labels, token_labels
|
||||
):
|
||||
model = LayoutLMv2ForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
bbox=bbox,
|
||||
image=image,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
bbox,
|
||||
image,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"bbox": bbox,
|
||||
"image": image,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": input_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_detectron2
|
||||
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
LayoutLMv2Model,
|
||||
LayoutLMv2ForSequenceClassification,
|
||||
LayoutLMv2ForTokenClassification,
|
||||
LayoutLMv2ForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = LayoutLMv2ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=LayoutLMv2Config, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
# make a copy of model class to not break future tests
|
||||
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
|
||||
class CopyClass(model_class):
|
||||
pass
|
||||
|
||||
model_class_copy = CopyClass
|
||||
|
||||
# make sure that all keys are expected for test
|
||||
model_class_copy._keys_to_ignore_on_load_missing = []
|
||||
|
||||
# make init deterministic, but make sure that
|
||||
# non-initialized weights throw errors nevertheless
|
||||
model_class_copy._init_weights = self._mock_init_weights
|
||||
|
||||
model = base_class(config)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# this will often delete a single weight of a multi-weight module
|
||||
# to test an edge case
|
||||
random_key_to_del = random.choice(list(state_dict.keys()))
|
||||
del state_dict[random_key_to_del]
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
|
||||
|
||||
model_fast_init = model_class_copy.from_pretrained(tmpdirname)
|
||||
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
if key == "layoutlmv2.visual_segment_embedding":
|
||||
# we skip the visual segment embedding as it has a custom initialization scheme
|
||||
continue
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
# LayoutLMv2 has a different expected sequence length
|
||||
expected_seq_len = (
|
||||
self.model_tester.seq_length
|
||||
+ self.model_tester.image_feature_pool_shape[0] * self.model_tester.image_feature_pool_shape[1]
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, expected_seq_len, expected_seq_len],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, expected_seq_len, expected_seq_len],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# LayoutLMv2 has a different expected sequence length
|
||||
expected_seq_len = (
|
||||
self.model_tester.seq_length
|
||||
+ self.model_tester.image_feature_pool_shape[0] * self.model_tester.image_feature_pool_shape[1]
|
||||
)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[expected_seq_len, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = LayoutLMv2Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if "backbone" in name or "visual_segment_embedding" in name:
|
||||
continue
|
||||
|
||||
if param.requires_grad:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
|
||||
def prepare_layoutlmv2_batch_inputs():
|
||||
# Here we prepare a batch of 2 sequences to test a LayoutLMv2 forward pass on:
    # fmt: off
    input_ids = torch.tensor([[101,1019,1014,1016,1037,12849,4747,1004,14246,2278,5439,4524,5002,2930,2193,2930,4341,3208,1005,1055,2171,2848,11300,3531,102],[101,4070,4034,7020,1024,3058,1015,1013,2861,1013,6070,19274,2772,6205,27814,16147,16147,4343,2047,10283,10969,14389,1012,2338,102]],device=torch_device) # noqa: E231
    bbox = torch.tensor([[[0,0,0,0],[423,237,440,251],[427,272,441,287],[419,115,437,129],[961,885,992,912],[256,38,330,58],[256,38,330,58],[336,42,353,57],[360,39,401,56],[360,39,401,56],[411,39,471,59],[479,41,528,59],[533,39,630,60],[67,113,134,131],[141,115,209,132],[68,149,133,166],[141,149,187,164],[195,148,287,165],[195,148,287,165],[195,148,287,165],[295,148,349,165],[441,149,492,166],[497,149,546,164],[64,201,125,218],[1000,1000,1000,1000]],[[0,0,0,0],[662,150,754,166],[665,199,742,211],[519,213,554,228],[519,213,554,228],[134,433,187,454],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[130,467,204,480],[314,469,376,482],[504,684,582,706],[941,825,973,900],[941,825,973,900],[941,825,973,900],[941,825,973,900],[610,749,652,765],[130,659,168,672],[176,657,237,672],[238,657,312,672],[443,653,628,672],[443,653,628,672],[716,301,825,317],[1000,1000,1000,1000]]],device=torch_device) # noqa: E231
    image = ImageList(torch.randn((2,3,224,224)), image_sizes=[(224,224), (224,224)]) # noqa: E231
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],],device=torch_device) # noqa: E231
    token_type_ids = torch.tensor([[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]],device=torch_device) # noqa: E231
    # fmt: on

    return input_ids, bbox, image, attention_mask, token_type_ids


@require_torch
@require_detectron2
class LayoutLMv2ModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased").to(torch_device)

        (
            input_ids,
            bbox,
            image,
            attention_mask,
            token_type_ids,
        ) = prepare_layoutlmv2_batch_inputs()

        # forward pass
        outputs = model(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # verify the sequence output
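        # LayoutLMv2 appends visual tokens to the text tokens; with the default image_feature_pool_shape
        # of [7, 7, 256], 7 * 7 = 49 pooled visual embeddings are added to the sequence length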
        expected_shape = torch.Size(
            (
                2,
                input_ids.shape[1]
                + model.config.image_feature_pool_shape[0] * model.config.image_feature_pool_shape[1],
                model.config.hidden_size,
            )
        )
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-0.1087, 0.0727, -0.3075], [0.0799, -0.0427, -0.0751], [-0.0367, 0.0480, -0.1358]], device=torch_device
        )
        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-3))

        # verify the pooled output
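        # pooler_output corresponds to the first ([CLS]) token after the model's pooler; only its shape is checked here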
        expected_shape = torch.Size((2, model.config.hidden_size))
        self.assertEqual(outputs.pooler_output.shape, expected_shape)

tests/test_processor_layoutlmv2.py (new file, 429 lines)
@@ -0,0 +1,429 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest
from typing import List

from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available
from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_pytesseract, require_tokenizers, require_torch, slow


if is_pytesseract_available():
    from PIL import Image

    from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2Processor

@require_pytesseract
@require_tokenizers
class LayoutLMv2ProcessorTest(unittest.TestCase):
    tokenizer_class = LayoutLMv2Tokenizer
    rust_tokenizer_class = LayoutLMv2TokenizerFast

    def setUp(self):
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]

        feature_extractor_map = {
            "do_resize": True,
            "size": 224,
            "apply_ocr": True,
        }

        self.tmpdirname = tempfile.mkdtemp()
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
        self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(feature_extractor_map) + "\n")

    def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
        return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
        return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
        return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]

    def get_feature_extractor(self, **kwargs):
        return LayoutLMv2FeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        feature_extractor = self.get_feature_extractor()
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            processor.save_pretrained(self.tmpdirname)
            processor = LayoutLMv2Processor.from_pretrained(self.tmpdirname)

            self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
            self.assertIsInstance(processor.tokenizer, (LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast))

            self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
            self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        processor = LayoutLMv2Processor(feature_extractor=self.get_feature_extractor(), tokenizer=self.get_tokenizer())
        processor.save_pretrained(self.tmpdirname)

        # slow tokenizer
        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)

        processor = LayoutLMv2Processor.from_pretrained(
            self.tmpdirname, use_fast=False, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, LayoutLMv2Tokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)

        # fast tokenizer
        tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_resize=False, size=30)

        processor = LayoutLMv2Processor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, LayoutLMv2TokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)

# different use cases tests
@require_torch
@require_pytesseract
class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
    @cached_property
    def get_images(self):
        # we verify our implementation on 2 document images from the DocVQA dataset
        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test")

        image_1 = Image.open(ds[0]["file"]).convert("RGB")
        image_2 = Image.open(ds[1]["file"]).convert("RGB")

        return image_1, image_2

    @cached_property
    def get_tokenizers(self):
        slow_tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
        fast_tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
        return [slow_tokenizer, fast_tokenizer]

    @slow
    def test_processor_case_1(self):
        # case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True
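        # with apply_ocr=True, the feature extractor runs Tesseract on each image and returns the recognized
        # words with their normalized bounding boxes, so only images need to be passed to the processor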

        feature_extractor = LayoutLMv2FeatureExtractor()
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            input_feat_extract = feature_extractor(images[0], return_tensors="pt")
            input_processor = processor(images[0], return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify image
            self.assertAlmostEqual(
                input_feat_extract["pixel_values"].sum(), input_processor["image"].sum(), delta=1e-2
            )

            # verify input_ids
            # fmt: off
expected_decoding = "[CLS] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
            # fmt: on
            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            input_feat_extract = feature_extractor(images, return_tensors="pt")
            input_processor = processor(images, padding=True, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify images
            self.assertAlmostEqual(
                input_feat_extract["pixel_values"].sum(), input_processor["image"].sum(), delta=1e-2
            )

            # verify input_ids
            # fmt: off
expected_decoding = "[CLS] 7 itc limited report and accounts 2013 itc ’ s brands : an asset for the nation the consumer needs and aspirations they fulfil, the benefit they generate for millions across itc ’ s value chains, the future - ready capabilities that support them, and the value that they create for the country, have made itc ’ s brands national assets, adding to india ’ s competitiveness. it is itc ’ s aspiration to be the no 1 fmcg player in the country, driven by its new fmcg businesses. a recent nielsen report has highlighted that itc's new fmcg businesses are the fastest growing among the top consumer goods companies operating in india. itc takes justifiable pride that, along with generating economic value, these celebrated indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. di wills * ; love delightfully soft skin? aia ans source : https : / / www. industrydocuments. ucsf. edu / docs / snbx0223 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]" # noqa: E231
            # fmt: on
            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

    @slow
    def test_processor_case_2(self):
        # case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False
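        # here apply_ocr=False: the words and their (already normalized) bounding boxes are supplied by the
        # caller rather than extracted by Tesseract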

        feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            words = ["hello", "world"]
            boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
            input_processor = processor(images[0], words, boxes=boxes, return_tensors="pt")

            # verify keys
            expected_keys = ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"]
            actual_keys = list(input_processor.keys())
            for key in expected_keys:
                self.assertIn(key, actual_keys)

            # verify input_ids
            expected_decoding = "[CLS] hello world [SEP]"
            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            words = [["hello", "world"], ["my", "name", "is", "niels"]]
            boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
            input_processor = processor(images, words, boxes=boxes, padding=True, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "[CLS] hello world [SEP] [PAD] [PAD] [PAD]"
            decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            expected_bbox = [
                [0, 0, 0, 0],
                [3, 2, 5, 1],
                [6, 7, 4, 2],
                [3, 9, 2, 4],
                [1, 1, 2, 3],
                [1, 1, 2, 3],
                [1000, 1000, 1000, 1000],
            ]
            self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)

    @slow
    def test_processor_case_3(self):
        # case 3: token classification (training), apply_ocr=False
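        # word-level labels are expanded to token-level labels: only the first subword of each word keeps the
        # word's label, while special tokens and remaining subwords get -100 (ignored by the loss); e.g. below,
        # "weirdly" is split into two subwords, giving labels [1, -100]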

        feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            words = ["weirdly", "world"]
            boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
            word_labels = [1, 2]
            input_processor = processor(images[0], words, boxes=boxes, word_labels=word_labels, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "[CLS] weirdly world [SEP]"
            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify labels
            expected_labels = [-100, 1, -100, 2, -100]
            self.assertListEqual(input_processor.labels.squeeze().tolist(), expected_labels)

            # batched
            words = [["hello", "world"], ["my", "name", "is", "niels"]]
            boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
            word_labels = [[1, 2], [6, 3, 10, 2]]
            input_processor = processor(
                images, words, boxes=boxes, word_labels=word_labels, padding=True, return_tensors="pt"
            )

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "[CLS] my name is niels [SEP]"
            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            expected_bbox = [
                [0, 0, 0, 0],
                [3, 2, 5, 1],
                [6, 7, 4, 2],
                [3, 9, 2, 4],
                [1, 1, 2, 3],
                [1, 1, 2, 3],
                [1000, 1000, 1000, 1000],
            ]
            self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)

            # verify labels
            expected_labels = [-100, 6, 3, 10, 2, -100, -100]
            self.assertListEqual(input_processor.labels[1].tolist(), expected_labels)

    @slow
    def test_processor_case_4(self):
        # case 4: visual question answering (inference), apply_ocr=True
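        # the question is encoded as the first sequence and the OCR'd document words as the second; question
        # tokens get the default [0, 0, 0, 0] box, while [SEP] tokens get [1000, 1000, 1000, 1000]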

        feature_extractor = LayoutLMv2FeatureExtractor()
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            question = "What's his name?"
            input_processor = processor(images[0], question, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            # fmt: off
expected_decoding = "[CLS] what's his name? [SEP] 11 : 14 to 11 : 39 a. m 11 : 39 to 11 : 44 a. m. 11 : 44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from the floor. dr. emil m. mrak, university of cal - ifornia, chairman, trrf board ; sam r. cecil, university of georgia college of agriculture ; dr. stanley charm, tufts university school of medicine ; dr. robert h. cotton, itt continental baking company ; dr. owen fennema, university of wis - consin ; dr. robert e. hardenburg, usda. questions and answers exhibits open capt. jack stoney room trrf scientific advisory council meeting ballroom foyer [SEP]" # noqa: E231
            # fmt: on
            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            questions = ["How old is he?", "what's the time"]
            input_processor = processor(
                images, questions, padding="max_length", max_length=20, truncation=True, return_tensors="pt"
            )

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "[CLS] what's the time [SEP] 7 itc limited report and accounts 2013 itc ’ s [SEP]"
            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            # fmt: off
            expected_bbox = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1000, 1000, 1000, 1000], [0, 45, 67, 80], [72, 56, 109, 67], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [74, 136, 161, 158], [74, 136, 161, 158], [74, 136, 161, 158], [1000, 1000, 1000, 1000]] # noqa: E231
            # fmt: on
            self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox)

    @slow
    def test_processor_case_5(self):
        # case 5: visual question answering (inference), apply_ocr=False
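        # same as case 4, but with apply_ocr=False the caller supplies the words and boxes alongside the question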

        feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
        tokenizers = self.get_tokenizers
        images = self.get_images

        for tokenizer in tokenizers:
            processor = LayoutLMv2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

            # not batched
            question = "What's his name?"
            words = ["hello", "world"]
            boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
            input_processor = processor(images[0], question, words, boxes, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "[CLS] what's his name? [SEP] hello world [SEP]"
            decoding = tokenizer.decode(input_processor.input_ids.squeeze().tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # batched
            questions = ["How old is he?", "what's the time"]
            words = [["hello", "world"], ["my", "name", "is", "niels"]]
            boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
            input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")

            # verify keys
            expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
            actual_keys = sorted(list(input_processor.keys()))
            self.assertListEqual(actual_keys, expected_keys)

            # verify input_ids
            expected_decoding = "[CLS] how old is he? [SEP] hello world [SEP] [PAD] [PAD] [PAD]"
            decoding = tokenizer.decode(input_processor.input_ids[0].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            expected_decoding = "[CLS] what's the time [SEP] my name is niels [SEP]"
            decoding = tokenizer.decode(input_processor.input_ids[1].tolist())
            self.assertSequenceEqual(decoding, expected_decoding)

            # verify bbox
            expected_bbox = [[6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [1000, 1000, 1000, 1000]]
            self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox)
tests/test_tokenization_layoutlmv2.py (new file, 1920 lines)
File diff suppressed because it is too large