mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Add splinter (#12955)
* splinter template
* initialize splinter classes
* Splinter Tokenizer
* splinter.rst
* tokenization fixes
* Documentation & some minor variable name changes
* bug fix (added back question_token_id to config) + variable names
* Minor bug fixes + variable name changes
* Fix Splinter references after merge with new transformers
* changes after running make style & quality
* Fix documentation unindent
* Fix doc indentation in tokenization_splinter
* Fix also SplinterTokenizerFast
* Add Splinter to index.rst and README
* Fix double whitespace in index.rst
* Fixed index.rst with 'make fix-copies'
* Update docs/source/model_doc/splinter.rst (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Update docs/source/model_doc/splinter.rst (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Update docs/source/model_doc/splinter.rst (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Update docs/source/model_doc/splinter.rst (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Update src/transformers/models/splinter/__init__.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Added "copied from BERT" comments
* Removing unnecessary code from modeling_splinter
* Update README.md (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/models/splinter/configuration_splinter.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Remove references to TF modeling from splinter
* Update src/transformers/models/splinter/modeling_splinter.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Remove unnecessary check
* Update src/transformers/models/splinter/modeling_splinter.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Add differences between Splinter and Bert tokenizers
* Update src/transformers/models/splinter/modeling_splinter.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/models/splinter/tokenization_splinter_fast.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Remove unnecessary check
* Doc formatting
* Update src/transformers/models/splinter/tokenization_splinter.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/models/splinter/tokenization_splinter.py (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* bug fix: remove load_tf_weights attribute
* Some minor quality changes
* Update docs/source/model_doc/splinter.rst (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/models/splinter/configuration_splinter.py (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* Change FullyConnectedLayer to SplinterFullyConnectedLayer
* Variable naming
* Remove gather_positions function
* Remove ClassificationHead as it's outdated
* Update src/transformers/models/splinter/modeling_splinter.py (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* Remove hardcoded 102 token id
* Minor style change
* Added "tau" organization to all model identifiers & URLs
* Added tau to the tests as well
* Copy-from comments
* Removed all unnecessary classes (e.g. SplinterForMaskedLM)
* Running make fix-copies
* Bug fix: Further removed unnecessary classes
* Add Splinter to AutoTokenization
* Add an integration test for Splinter
* Removed initialize_new_qass from config - it will be done through different checkpoints
* Removed `initialize_new_qass` from documentation as well
* Added new checkpoint names (`tau/splinter-base-qass` and same for large) in the code
* Minor change to test
* SplinterTokenizer no longer inherits from BertTokenizer
* SplinterTokenizerFast also no longer inherits from Bert
* style and quality
* bug fix: import torch in tests only if it's available
* Auto mappings
* Changed copyrights in Splinter's files
* Update src/transformers/models/splinter/configuration_splinter.py (Co-authored-by: Lysandre Debut <lysandre@huggingface.co>)

Co-authored-by: yuvalkirstain <kirstain.yuval@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent 6626d8a62f
commit 439a43b6b4
@@ -263,6 +263,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
1. **[Splinter](https://huggingface.co/transformers/model_doc/master/splinter.html)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
@@ -257,41 +257,44 @@ Supported models
53. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
    `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
    Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
54. `Splinter <https://huggingface.co/transformers/model_doc/master/splinter.html>`__ (from Tel Aviv University),
    released together with the paper `Few-Shot Question Answering by Pretraining Span Selection
    <https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
55. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
    about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
    Krishna, and Kurt W. Keutzer.
56. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
    Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
    Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
57. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
    Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
    Francesco Piccinno and Julian Martin Eisenschlos.
58. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
    Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
    Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
59. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
    Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
    Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
    Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
60. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
    Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
    Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
61. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
    Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
    Zhou, Abdelrahman Mohamed, Michael Auli.
62. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
    Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
63. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
    Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
    Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
64. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
    Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
    Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
    Zettlemoyer and Veselin Stoyanov.
65. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
    Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
    Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
66. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
    Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
    Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
@@ -413,6 +416,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Speech2Text | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -571,6 +576,7 @@ Flax), PyTorch, and/or TensorFlow.
    model_doc/roberta
    model_doc/roformer
    model_doc/speech_to_text
    model_doc/splinter
    model_doc/squeezebert
    model_doc/t5
    model_doc/tapas
87  docs/source/model_doc/splinter.rst  Normal file
@@ -0,0 +1,87 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Splinter
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Splinter model was proposed in `Few-Shot Question Answering by Pretraining Span Selection
<https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. Splinter
is an encoder-only transformer (similar to BERT) pretrained using the recurring span selection task on a large corpus
comprising Wikipedia and the Toronto Book Corpus.

The abstract from the paper is the following:

In several question answering benchmarks, pretrained models have reached human parity through fine-tuning on an order
of 100,000 annotated questions and answers. We explore the more realistic few-shot setting, where only a few hundred
training examples are available, and observe that standard models perform poorly, highlighting the discrepancy between
current pretraining objectives and question answering. We propose a new pretraining scheme tailored for question
answering: recurring span selection. Given a passage with multiple sets of recurring spans, we mask in each set all
recurring spans but one, and ask the model to select the correct span in the passage for each masked span. Masked spans
are replaced with a special token, viewed as a question representation, that is later used during fine-tuning to select
the answer span. The resulting model obtains surprisingly good results on multiple benchmarks (e.g., 72.7 F1 on SQuAD
with only 128 training examples), while maintaining competitive performance in the high-resource setting.

Tips:

- Splinter was trained to predict answer spans conditioned on a special [QUESTION] token. These tokens contextualize
  to question representations which are used to predict the answers. This layer is called QASS, and is the default
  behaviour in the :class:`~transformers.SplinterForQuestionAnswering` class. Therefore:

  - Use :class:`~transformers.SplinterTokenizer` (rather than :class:`~transformers.BertTokenizer`), as it already
    contains this special token. Also, its default behavior is to use this token when two sequences are given (for
    example, in the `run_qa.py` script).
  - If you plan on using Splinter outside `run_qa.py`, please keep in mind the question token - it might be important
    for the success of your model, especially in a few-shot setting.

- Please note there are two different checkpoints for each size of Splinter. Both are basically the same, except that
  one also has the pretrained weights of the QASS layer (`tau/splinter-base-qass` and `tau/splinter-large-qass`) and
  one doesn't (`tau/splinter-base` and `tau/splinter-large`). This is done to support randomly initializing this layer
  at fine-tuning, as it is shown to yield better results for some cases in the paper.

This model was contributed by `yuvalkirstain <https://huggingface.co/yuvalkirstain>`__ and `oriram
<https://huggingface.co/oriram>`__. The original code can be found `here <https://github.com/oriram/splinter>`__.
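The following is a minimal inference sketch (not part of the original file, added for illustration). It assumes the
`tau/splinter-base-qass` checkpoint, i.e. the variant that ships pretrained QASS weights, and a PyTorch install::

    >>> import torch
    >>> from transformers import SplinterForQuestionAnswering, SplinterTokenizer

    >>> tokenizer = SplinterTokenizer.from_pretrained("tau/splinter-base-qass")
    >>> model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base-qass")

    >>> # Passing a (question, context) pair lets the tokenizer insert the special [QUESTION] token for us
    >>> question = "Who proposed Splinter?"
    >>> context = "Splinter was proposed by researchers from Tel Aviv University."
    >>> inputs = tokenizer(question, context, return_tensors="pt")

    >>> with torch.no_grad():
    ...     outputs = model(**inputs)
    >>> start = int(outputs.start_logits.argmax())
    >>> end = int(outputs.end_logits.argmax())
    >>> answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])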
SplinterConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SplinterConfig
    :members:


SplinterTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SplinterTokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


SplinterTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SplinterTokenizerFast
    :members:


SplinterModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SplinterModel
    :members: forward


SplinterForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.SplinterForQuestionAnswering
    :members: forward
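A final illustrative sketch (again not part of the original file): the checkpoints without the `-qass` suffix ship no
QASS weights, so loading them into the question-answering class is expected to report a freshly initialized
span-selection head, which is the intended starting point for fine-tuning::

    >>> from transformers import SplinterForQuestionAnswering

    >>> # Expect a warning about newly initialized weights for the QASS head
    >>> model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base")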
@@ -233,6 +233,7 @@ _import_structure = {
        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "Speech2TextConfig",
    ],
    "models.splinter": ["SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SplinterConfig", "SplinterTokenizer"],
    "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
    "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
    "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"],
@@ -369,6 +370,7 @@ if is_tokenizers_available():
    _import_structure["models.rembert"].append("RemBertTokenizerFast")
    _import_structure["models.retribert"].append("RetriBertTokenizerFast")
    _import_structure["models.roberta"].append("RobertaTokenizerFast")
    _import_structure["models.splinter"].append("SplinterTokenizerFast")
    _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
    _import_structure["models.t5"].append("T5TokenizerFast")
    _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast")
@@ -1046,6 +1048,15 @@ if is_torch_available():
            "Speech2TextPreTrainedModel",
        ]
    )
    _import_structure["models.splinter"].extend(
        [
            "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
            "SplinterForQuestionAnswering",
            "SplinterLayer",
            "SplinterModel",
            "SplinterPreTrainedModel",
        ]
    )
    _import_structure["models.squeezebert"].extend(
        [
            "SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1914,6 +1925,7 @@ if TYPE_CHECKING:
    from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
    from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer
    from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
    from .models.splinter import SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP, SplinterConfig, SplinterTokenizer
    from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
    from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
    from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
@@ -2045,6 +2057,7 @@ if TYPE_CHECKING:
        from .models.retribert import RetriBertTokenizerFast
        from .models.roberta import RobertaTokenizerFast
        from .models.roformer import RoFormerTokenizerFast
        from .models.splinter import SplinterTokenizerFast
        from .models.squeezebert import SqueezeBertTokenizerFast
        from .models.t5 import T5TokenizerFast
        from .models.xlm_roberta import XLMRobertaTokenizerFast
@@ -2602,6 +2615,13 @@ if TYPE_CHECKING:
            Speech2TextModel,
            Speech2TextPreTrainedModel,
        )
        from .models.splinter import (
            SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
            SplinterForQuestionAnswering,
            SplinterLayer,
            SplinterModel,
            SplinterPreTrainedModel,
        )
        from .models.squeezebert import (
            SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            SqueezeBertForMaskedLM,
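The six hunks above only wire Splinter into the library's lazy `_import_structure` / `TYPE_CHECKING` machinery. As a
quick sanity check (a sketch, not part of the commit, assuming PyTorch and the `tokenizers` package are installed), the
new names should now be importable from the top level:

from transformers import (
    SplinterConfig,
    SplinterForQuestionAnswering,
    SplinterLayer,
    SplinterModel,
    SplinterPreTrainedModel,
    SplinterTokenizer,
    SplinterTokenizerFast,
)

# Instantiating the default config mirrors the tau/splinter-base architecture defined further down in this diff.
config = SplinterConfig()
model = SplinterForQuestionAnswering(config)  # randomly initialized, no download needed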
@@ -108,6 +108,56 @@ class BertConverter(Converter):
        return tokenizer


class SplinterConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        question = str(self.original_tokenizer.question_token)
        dot = "."
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id
        question_token_id = self.original_tokenizer.question_token_id
        dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")

        if self.original_tokenizer.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
                (question, question_token_id),
                (dot, dot_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class FunnelConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
@@ -829,6 +879,7 @@ SLOW_TO_FAST_CONVERTERS = {
    "T5Tokenizer": T5Converter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
}
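For illustration only (this sketch is not part of the diff), the `TemplateProcessing` pair template above lays out a
right-padded (question, context) encoding roughly as follows; the token strings are placeholders standing in for
whatever the slow tokenizer actually defines:

# Hypothetical layout implied by pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
question_tokens = ["who", "proposed", "splinter", "?"]
context_tokens = ["ori", "ram", "and", "colleagues", "proposed", "it", "."]

pair_layout = (
    ["[CLS]"] + question_tokens + ["[QUESTION]", ".", "[SEP]"]  # sequence A, token_type_id 0
    + context_tokens + ["[SEP]"]                                # sequence B, token_type_id 1
)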
@@ -78,6 +78,7 @@ from . import (
    roberta,
    roformer,
    speech_to_text,
    splinter,
    squeezebert,
    t5,
    tapas,
@@ -87,6 +87,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("layoutlm", "LayoutLMConfig"),
        ("rag", "RagConfig"),
        ("tapas", "TapasConfig"),
        ("splinter", "SplinterConfig"),
    ]
)
@@ -147,6 +148,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
        ("tapas", "TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("ibert", "IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("hubert", "HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("splinter", "SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
    ]
)
@@ -222,6 +224,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("bert-japanese", "BertJapanese"),
        ("byt5", "ByT5"),
        ("mbart50", "mBART-50"),
        ("splinter", "Splinter"),
    ]
)
@@ -88,6 +88,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("mpnet", "MPNetModel"),
        ("tapas", "TapasModel"),
        ("ibert", "IBertModel"),
        ("splinter", "SplinterModel"),
    ]
)
@@ -356,6 +357,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
    ]
)
@@ -174,6 +174,7 @@ else:
        ("canine", ("CanineTokenizer", None)),
        ("bertweet", ("BertweetTokenizer", None)),
        ("bert-japanese", ("BertJapaneseTokenizer", None)),
        ("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
        ("byt5", ("ByT5Tokenizer", None)),
        (
            "cpm",
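Taken together, the entries above register the "splinter" model type with the generic Auto* classes. A rough usage
sketch (not part of the commit; it assumes the tau checkpoints referenced elsewhere in this PR are available on the
Hub):

from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer

config = AutoConfig.from_pretrained("tau/splinter-base")  # resolves to SplinterConfig
tokenizer = AutoTokenizer.from_pretrained("tau/splinter-base-qass")  # SplinterTokenizerFast when tokenizers is installed
model = AutoModelForQuestionAnswering.from_pretrained("tau/splinter-base-qass")  # SplinterForQuestionAnswering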
61  src/transformers/models/splinter/__init__.py  Normal file
@@ -0,0 +1,61 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
    "configuration_splinter": ["SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SplinterConfig"],
    "tokenization_splinter": ["SplinterTokenizer"],
}

if is_tokenizers_available():
    _import_structure["tokenization_splinter_fast"] = ["SplinterTokenizerFast"]

if is_torch_available():
    _import_structure["modeling_splinter"] = [
        "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SplinterForQuestionAnswering",
        "SplinterLayer",
        "SplinterModel",
        "SplinterPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_splinter import SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP, SplinterConfig
    from .tokenization_splinter import SplinterTokenizer

    if is_tokenizers_available():
        from .tokenization_splinter_fast import SplinterTokenizerFast

    if is_torch_available():
        from .modeling_splinter import (
            SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
            SplinterForQuestionAnswering,
            SplinterLayer,
            SplinterModel,
            SplinterPreTrainedModel,
        )


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
128  src/transformers/models/splinter/configuration_splinter.py  Normal file
@@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2021 Tel Aviv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Splinter model configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

SPLINTER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "tau/splinter-base": "https://huggingface.co/tau/splinter-base/resolve/main/config.json",
    "tau/splinter-base-qass": "https://huggingface.co/tau/splinter-base-qass/resolve/main/config.json",
    "tau/splinter-large": "https://huggingface.co/tau/splinter-large/resolve/main/config.json",
    "tau/splinter-large-qass": "https://huggingface.co/tau/splinter-large-qass/resolve/main/config.json",
    # See all Splinter models at https://huggingface.co/models?filter=splinter
}


class SplinterConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.SplinterModel`. It is used
    to instantiate a Splinter model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Splinter
    `tau/splinter-base <https://huggingface.co/tau/splinter-base>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the Splinter model. Defines the number of different tokens that can be represented by
            the :obj:`inputs_ids` passed when calling :class:`~transformers.SplinterModel`.
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.SplinterModel`.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if ``config.is_decoder=True``.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
        question_token_id (:obj:`int`, `optional`, defaults to 104):
            The id of the ``[QUESTION]`` token.

    Example::

        >>> from transformers import SplinterModel, SplinterConfig

        >>> # Initializing a Splinter tau/splinter-base style configuration
        >>> configuration = SplinterConfig()

        >>> # Initializing a model from the tau/splinter-base style configuration
        >>> model = SplinterModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "splinter"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        use_cache=True,
        pad_token_id=0,
        question_token_id=104,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.question_token_id = question_token_id
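A small consistency check worth illustrating (sketch only, not part of the file above): `question_token_id` defaults
to 104, and whichever tokenizer is paired with the model has to agree on where its `[QUESTION]` token lives.

from transformers import SplinterConfig, SplinterTokenizer

tokenizer = SplinterTokenizer.from_pretrained("tau/splinter-base")
config = SplinterConfig.from_pretrained("tau/splinter-base")

# The slow tokenizer exposes the id of its [QUESTION] token; it should match the config value.
assert config.question_token_id == tokenizer.question_token_id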
932
src/transformers/models/splinter/modeling_splinter.py
Executable file
932
src/transformers/models/splinter/modeling_splinter.py
Executable file
@ -0,0 +1,932 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Splinter model. """
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
|
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput
|
||||||
|
from ...modeling_utils import (
|
||||||
|
PreTrainedModel,
|
||||||
|
apply_chunking_to_forward,
|
||||||
|
find_pruneable_heads_and_indices,
|
||||||
|
prune_linear_layer,
|
||||||
|
)
|
||||||
|
from ...utils import logging
|
||||||
|
from .configuration_splinter import SplinterConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "tau/splinter-base"
|
||||||
|
_CONFIG_FOR_DOC = "SplinterConfig"
|
||||||
|
_TOKENIZER_FOR_DOC = "SplinterTokenizer"
|
||||||
|
|
||||||
|
SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
|
"tau/splinter-base",
|
||||||
|
"tau/splinter-base-qass",
|
||||||
|
"tau/splinter-large",
|
||||||
|
"tau/splinter-large-qass",
|
||||||
|
# See all Splinter models at https://huggingface.co/models?filter=splinter
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SplinterEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||||
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||||
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||||
|
|
||||||
|
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||||
|
# any TensorFlow checkpoint file
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
|
if input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
else:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
|
||||||
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
|
embeddings = inputs_embeds + token_type_embeddings
|
||||||
|
if self.position_embedding_type == "absolute":
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
embeddings += position_embeddings
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
embeddings = self.dropout(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
|
||||||
|
class SplinterSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||||
|
raise ValueError(
|
||||||
|
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||||
|
f"heads ({config.num_attention_heads})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x):
|
||||||
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
return x.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
|
# and values come from an encoder; the attention mask needs to be
|
||||||
|
# such that the encoder's padding tokens are not attended to.
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
|
else:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||||
|
seq_length = hidden_states.size()[1]
|
||||||
|
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||||
|
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||||
|
distance = position_ids_l - position_ids_r
|
||||||
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||||
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||||
|
|
||||||
|
if self.position_embedding_type == "relative_key":
|
||||||
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores
|
||||||
|
elif self.position_embedding_type == "relative_key_query":
|
||||||
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||||
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||||
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||||
|
|
||||||
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||||
|
if attention_mask is not None:
|
||||||
|
            # Apply the attention mask (precomputed for all layers in SplinterModel's forward() function)
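            # The extended mask holds 0.0 for positions to attend to and a large negative value
            # (e.g. -10000.0) for masked positions, so adding it pushes masked logits towards -inf
            # before the softmax.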
|
||||||
|
attention_scores = attention_scores + attention_mask
|
||||||
|
|
||||||
|
# Normalize the attention scores to probabilities.
|
||||||
|
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||||
|
|
||||||
|
# This is actually dropping out entire tokens to attend to, which might
|
||||||
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
|
attention_probs = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
attention_probs = attention_probs * head_mask
|
||||||
|
|
||||||
|
context_layer = torch.matmul(attention_probs, value_layer)
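        # Merge the heads back together: (batch, heads, seq, head_size) -> (batch, seq, all_head_size)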
|
||||||
|
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Splinter
|
||||||
|
class SplinterSelfOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
|
||||||
|
class SplinterAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.self = SplinterSelfAttention(config)
|
||||||
|
self.output = SplinterSelfOutput(config)
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
|
def prune_heads(self, heads):
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
|
heads, index = find_pruneable_heads_and_indices(
|
||||||
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prune linear layers
|
||||||
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
|
self.self.key = prune_linear_layer(self.self.key, index)
|
||||||
|
self.self.value = prune_linear_layer(self.self.value, index)
|
||||||
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||||
|
|
||||||
|
# Update hyper params and store pruned heads
|
||||||
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
self_outputs = self.self(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Splinter
|
||||||
|
class SplinterIntermediate(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Splinter
|
||||||
|
class SplinterOutput(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter
|
||||||
|
class SplinterLayer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||||
|
self.seq_len_dim = 1
|
||||||
|
self.attention = SplinterAttention(config)
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.add_cross_attention = config.add_cross_attention
|
||||||
|
if self.add_cross_attention:
|
||||||
|
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
|
||||||
|
self.crossattention = SplinterAttention(config)
|
||||||
|
self.intermediate = SplinterIntermediate(config)
|
||||||
|
self.output = SplinterOutput(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
|
self_attention_outputs = self.attention(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
|
)
|
||||||
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
|
assert hasattr(
|
||||||
|
self, "crossattention"
|
||||||
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
|
cross_attention_outputs = self.crossattention(
|
||||||
|
attention_output,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
attention_output = cross_attention_outputs[0]
|
||||||
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
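            # `present_key_value` is now a 4-tuple:
            # (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)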
|
||||||
|
|
||||||
|
layer_output = apply_chunking_to_forward(
|
||||||
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
|
)
|
||||||
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def feed_forward_chunk(self, attention_output):
|
||||||
|
intermediate_output = self.intermediate(attention_output)
|
||||||
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter
|
||||||
|
class SplinterEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
):
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
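        # Each entry appended below is the (key, value) cache returned by a layer; the resulting
        # tuple can be fed back as `past_key_values` on the next decoding step.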
|
||||||
|
for i, layer_module in enumerate(self.layer):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||||
|
"`use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
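                # torch.utils.checkpoint.checkpoint does not forward keyword arguments, so the layer
                # call is wrapped in a closure that fixes `past_key_value` and `output_attentions`.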
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(layer_module),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = layer_module(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
layer_head_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SplinterPreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = SplinterConfig
|
||||||
|
base_model_prefix = "splinter"
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""Initialize the weights"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
SPLINTER_START_DOCSTRING = r"""
|
||||||
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
|
||||||
|
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
|
||||||
|
behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.SplinterConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||||
|
weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SPLINTER_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
||||||
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`transformers.SplinterTokenizer`. See
|
||||||
|
:func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
||||||
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||||
|
1]``:
|
||||||
|
|
||||||
|
- 0 corresponds to a `sentence A` token,
|
||||||
|
- 1 corresponds to a `sentence B` token.
|
||||||
|
|
||||||
|
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||||
|
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
|
config.max_position_embeddings - 1]``.
|
||||||
|
|
||||||
|
`What are position IDs? <../glossary.html#position-ids>`_
|
||||||
|
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||||
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||||
|
than the model's internal embedding lookup matrix.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Splinter Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
SPLINTER_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class SplinterModel(SplinterPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model is an encoder (with only self-attention) following the architecture described in `Attention is all you
|
||||||
|
need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion
|
||||||
|
Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.embeddings = SplinterEmbeddings(config)
|
||||||
|
self.encoder = SplinterEncoder(config)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
|
def _prune_heads(self, heads_to_prune):
|
||||||
|
"""
|
||||||
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||||
|
class PreTrainedModel
|
||||||
|
"""
|
||||||
|
for layer, heads in heads_to_prune.items():
|
||||||
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||||
|
the model is configured as a decoder.
|
||||||
|
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||||
|
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
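        # The attention mask has to cover both the cached positions and the newly provided tokens,
        # hence `seq_length + past_key_values_length` below.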
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_extended_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
embedding_output = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
)
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
embedding_output,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
sequence_output = encoder_outputs[0]
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sequence_output,) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=sequence_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SplinterFullyConnectedLayer(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, hidden_act="gelu"):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
self.dense = nn.Linear(self.input_dim, self.output_dim)
|
||||||
|
self.act_fn = ACT2FN[hidden_act]
|
||||||
|
self.LayerNorm = nn.LayerNorm(self.output_dim)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
hidden_states = self.dense(inputs)
|
||||||
|
hidden_states = self.act_fn(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionAwareSpanSelectionHead(nn.Module):
|
||||||
|
"""
|
||||||
|
    Implementation of the Question-Aware Span Selection (QASS) head, described in Splinter's paper:
    `Few-Shot Question Answering by Pretraining Span Selection <https://arxiv.org/abs/2101.00438>`__.
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.query_start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
|
||||||
|
self.query_end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
|
||||||
|
self.start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
|
||||||
|
self.end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
|
||||||
|
|
||||||
|
self.start_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||||
|
self.end_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, inputs, positions):
|
||||||
|
_, _, dim = inputs.size()
|
||||||
|
index = positions.unsqueeze(-1).repeat(1, 1, dim) # [batch_size, num_positions, dim]
|
||||||
|
gathered_reps = torch.gather(inputs, dim=1, index=index) # [batch_size, num_positions, dim]
|
||||||
|
|
||||||
|
query_start_reps = self.query_start_transform(gathered_reps) # [batch_size, num_positions, dim]
|
||||||
|
query_end_reps = self.query_end_transform(gathered_reps) # [batch_size, num_positions, dim]
|
||||||
|
start_reps = self.start_transform(inputs) # [batch_size, seq_length, dim]
|
||||||
|
end_reps = self.end_transform(inputs) # [batch_size, seq_length, dim]
|
||||||
|
|
||||||
|
hidden_states = self.start_classifier(query_start_reps) # [batch_size, num_positions, dim]
|
||||||
|
start_reps = start_reps.permute(0, 2, 1) # [batch_size, dim, seq_length]
|
||||||
|
start_logits = torch.matmul(hidden_states, start_reps)
|
||||||
|
|
||||||
|
hidden_states = self.end_classifier(query_end_reps)
|
||||||
|
end_reps = end_reps.permute(0, 2, 1)
|
||||||
|
end_logits = torch.matmul(hidden_states, end_reps)
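        # Both logits tensors have shape (batch_size, num_positions, seq_length): for every gathered
        # question token, a bilinear score against each token in the sequence.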
|
||||||
|
|
||||||
|
return start_logits, end_logits
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
    Splinter Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
||||||
|
""",
|
||||||
|
SPLINTER_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class SplinterForQuestionAnswering(SplinterPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.splinter = SplinterModel(config)
|
||||||
|
self.splinter_qass = QuestionAwareSpanSelectionHead(config)
|
||||||
|
self.question_token_id = config.question_token_id
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=QuestionAnsweringModelOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
head_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
start_positions=None,
|
||||||
|
end_positions=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
question_positions=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
|
||||||
|
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
|
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||||
|
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
|
||||||
|
question_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_questions)`, `optional`):
|
||||||
|
The positions of all question tokens. If given, start_logits and end_logits will be of shape
|
||||||
|
:obj:`(batch_size, num_questions, sequence_length)`. If None, the first question token in each sequence in
|
||||||
|
the batch will be the only one for which start_logits and end_logits are calculated and they will be of
|
||||||
|
shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
question_positions_were_none = False
|
||||||
|
if question_positions is None:
|
||||||
|
if input_ids is not None:
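                # torch.eq marks every [QUESTION] position; argmax then returns the index of a matching
                # position (the first occurrence in practice), falling back to 0 when the token is absent.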
|
||||||
|
question_position_for_each_example = torch.argmax(
|
||||||
|
(torch.eq(input_ids, self.question_token_id)).int(), dim=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
question_position_for_each_example = torch.zeros(
|
||||||
|
inputs_embeds.size(0), dtype=torch.long, layout=inputs_embeds.layout, device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
question_positions = question_position_for_each_example.unsqueeze(-1)
|
||||||
|
question_positions_were_none = True
|
||||||
|
|
||||||
|
outputs = self.splinter(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
|
||||||
|
|
||||||
|
if question_positions_were_none:
|
||||||
|
start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)
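        # Push logits at padding positions far below any real score so they never win the argmax.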
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
start_logits = start_logits + (1 - attention_mask) * -10000.0
|
||||||
|
end_logits = end_logits + (1 - attention_mask) * -10000.0
|
||||||
|
|
||||||
|
total_loss = None
|
||||||
|
if start_positions is not None and end_positions is not None:
|
||||||
|
            # If we are on multi-GPU, split adds a dimension
|
||||||
|
if len(start_positions.size()) > 1:
|
||||||
|
start_positions = start_positions.squeeze(-1)
|
||||||
|
if len(end_positions.size()) > 1:
|
||||||
|
end_positions = end_positions.squeeze(-1)
|
||||||
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
|
ignored_index = start_logits.size(1)
|
||||||
|
start_positions.clamp_(0, ignored_index)
|
||||||
|
end_positions.clamp_(0, ignored_index)
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
|
total_loss = (start_loss + end_loss) / 2
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (start_logits, end_logits) + outputs[1:]
|
||||||
|
return ((total_loss,) + output) if total_loss is not None else output
|
||||||
|
|
||||||
|
return QuestionAnsweringModelOutput(
|
||||||
|
loss=total_loss,
|
||||||
|
start_logits=start_logits,
|
||||||
|
end_logits=end_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
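# Illustrative usage sketch: how SplinterTokenizer and SplinterForQuestionAnswering might be used once
# the `tau/splinter-base-qass` checkpoint added in this PR is available; the question and context
# strings are made-up examples.
#
#     from transformers import SplinterTokenizer, SplinterForQuestionAnswering
#     import torch
#
#     tokenizer = SplinterTokenizer.from_pretrained("tau/splinter-base-qass")
#     model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base-qass")
#
#     question = "Who developed Splinter?"
#     context = "Splinter was developed by researchers at Tel Aviv University and AllenAI."
#     inputs = tokenizer(question, context, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#
#     start = int(outputs.start_logits.argmax())
#     end = int(outputs.end_logits.argmax())
#     answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])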
|
src/transformers/models/splinter/tokenization_splinter.py (new file, 531 lines)
@@ -0,0 +1,531 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 Tel Aviv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization classes for Splinter."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||||
|
|
||||||
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
|
"vocab_file": {
|
||||||
|
"tau/splinter-base": "https://huggingface.co/tau/splinter-base/resolve/main/vocab.txt",
|
||||||
|
"tau/splinter-base-qass": "https://huggingface.co/tau/splinter-base-qass/resolve/main/vocab.txt",
|
||||||
|
"tau/splinter-large": "https://huggingface.co/tau/splinter-large/resolve/main/vocab.txt",
|
||||||
|
"tau/splinter-large-qass": "https://huggingface.co/tau/splinter-large-qass/resolve/main/vocab.txt",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
|
"tau/splinter-base": 512,
|
||||||
|
"tau/splinter-base-qass": 512,
|
||||||
|
"tau/splinter-large": 512,
|
||||||
|
"tau/splinter-large-qass": 512,
|
||||||
|
}
|
||||||
|
|
||||||
|
PRETRAINED_INIT_CONFIGURATION = {
|
||||||
|
"tau/splinter-base": {"do_lower_case": False},
|
||||||
|
"tau/splinter-base-qass": {"do_lower_case": False},
|
||||||
|
"tau/splinter-large": {"do_lower_case": False},
|
||||||
|
"tau/splinter-large-qass": {"do_lower_case": False},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocab(vocab_file):
|
||||||
|
"""Loads a vocabulary file into a dictionary."""
|
||||||
|
vocab = collections.OrderedDict()
|
||||||
|
with open(vocab_file, "r", encoding="utf-8") as reader:
|
||||||
|
tokens = reader.readlines()
|
||||||
|
for index, token in enumerate(tokens):
|
||||||
|
token = token.rstrip("\n")
|
||||||
|
vocab[token] = index
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_tokenize(text):
|
||||||
|
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
tokens = text.split()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class SplinterTokenizer(PreTrainedTokenizer):
|
||||||
|
r"""
|
||||||
|
Construct a Splinter tokenizer. Based on WordPiece.
|
||||||
|
|
||||||
|
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
|
||||||
|
Users should refer to this superclass for more information regarding those methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (:obj:`str`):
|
||||||
|
File containing the vocabulary.
|
||||||
|
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to lowercase the input when tokenizing.
|
||||||
|
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to do basic tokenization before WordPiece.
|
||||||
|
never_split (:obj:`Iterable`, `optional`):
|
||||||
|
Collection of tokens which will never be split during tokenization. Only has an effect when
|
||||||
|
:obj:`do_basic_tokenize=True`
|
||||||
|
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
|
||||||
|
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||||
|
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||||
|
token of a sequence built with special tokens.
|
||||||
|
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
|
||||||
|
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||||
|
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||||
|
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
|
||||||
|
The token used for masking values. This is the token used when training this model with masked language
|
||||||
|
modeling. This is the token which the model will try to predict.
|
||||||
|
question_token (:obj:`str`, `optional`, defaults to :obj:`"[QUESTION]"`):
|
||||||
|
The token used for constructing question representations.
|
||||||
|
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to tokenize Chinese characters.
|
||||||
|
|
||||||
|
This should likely be deactivated for Japanese (see this `issue
|
||||||
|
<https://github.com/huggingface/transformers/issues/328>`__).
|
||||||
|
strip_accents: (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
||||||
|
value for :obj:`lowercase` (as in the original BERT).
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
do_lower_case=True,
|
||||||
|
do_basic_tokenize=True,
|
||||||
|
never_split=None,
|
||||||
|
unk_token="[UNK]",
|
||||||
|
sep_token="[SEP]",
|
||||||
|
pad_token="[PAD]",
|
||||||
|
cls_token="[CLS]",
|
||||||
|
mask_token="[MASK]",
|
||||||
|
question_token="[QUESTION]",
|
||||||
|
tokenize_chinese_chars=True,
|
||||||
|
strip_accents=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
do_basic_tokenize=do_basic_tokenize,
|
||||||
|
never_split=never_split,
|
||||||
|
unk_token=unk_token,
|
||||||
|
sep_token=sep_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
cls_token=cls_token,
|
||||||
|
mask_token=mask_token,
|
||||||
|
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||||
|
strip_accents=strip_accents,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not os.path.isfile(vocab_file):
|
||||||
|
raise ValueError(
|
||||||
|
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
|
||||||
|
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||||
|
)
|
||||||
|
self.vocab = load_vocab(vocab_file)
|
||||||
|
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
||||||
|
self.do_basic_tokenize = do_basic_tokenize
|
||||||
|
if do_basic_tokenize:
|
||||||
|
self.basic_tokenizer = BasicTokenizer(
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
never_split=never_split,
|
||||||
|
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||||
|
strip_accents=strip_accents,
|
||||||
|
)
|
||||||
|
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
||||||
|
self.question_token = question_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def question_token_id(self):
|
||||||
|
"""
|
||||||
|
:obj:`Optional[int]`: Id of the question token in the vocabulary, used to condition the answer on a question
|
||||||
|
representation.
|
||||||
|
"""
|
||||||
|
return self.convert_tokens_to_ids(self.question_token)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def do_lower_case(self):
|
||||||
|
return self.basic_tokenizer.do_lower_case
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self.vocab)
|
||||||
|
|
||||||
|
def get_vocab(self):
|
||||||
|
return dict(self.vocab, **self.added_tokens_encoder)
|
||||||
|
|
||||||
|
def _tokenize(self, text):
|
||||||
|
split_tokens = []
|
||||||
|
if self.do_basic_tokenize:
|
||||||
|
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||||
|
|
||||||
|
# If the token is part of the never_split set
|
||||||
|
if token in self.basic_tokenizer.never_split:
|
||||||
|
split_tokens.append(token)
|
||||||
|
else:
|
||||||
|
split_tokens += self.wordpiece_tokenizer.tokenize(token)
|
||||||
|
else:
|
||||||
|
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
||||||
|
return split_tokens
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token):
|
||||||
|
"""Converts a token (str) in an id using the vocab."""
|
||||||
|
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index):
|
||||||
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
return self.ids_to_tokens.get(index, self.unk_token)
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens):
|
||||||
|
"""Converts a sequence of tokens (string) in a single string."""
|
||||||
|
out_string = " ".join(tokens).replace(" ##", "").strip()
|
||||||
|
return out_string
|
||||||
|
|
||||||
|
def build_inputs_with_special_tokens(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
        Build model inputs from a pair of sequences for question answering tasks by concatenating and adding special
|
||||||
|
tokens. A Splinter sequence has the following format:
|
||||||
|
|
||||||
|
- single sequence: ``[CLS] X [SEP]``
|
||||||
|
- pair of sequences for question answering: ``[CLS] question_tokens [QUESTION] . [SEP] context_tokens [SEP]``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
                The question token IDs if ``padding_side == "right"``, else the context token IDs.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
                The context token IDs if ``padding_side == "right"``, else the question token IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||||
|
"""
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
|
|
||||||
|
cls = [self.cls_token_id]
|
||||||
|
sep = [self.sep_token_id]
|
||||||
|
question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")]
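        # A question segment is always terminated with the [QUESTION] token followed by a "." token
        # before the separating [SEP].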
|
||||||
|
if self.padding_side == "right":
|
||||||
|
# Input is question-then-context
|
||||||
|
return cls + token_ids_0 + question_suffix + sep + token_ids_1 + sep
|
||||||
|
else:
|
||||||
|
# Input is context-then-question
|
||||||
|
return cls + token_ids_0 + sep + token_ids_1 + question_suffix + sep
|
||||||
|
|
||||||
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
|
special tokens using the tokenizer ``prepare_for_model`` method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not the token list is already formatted with special tokens for the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if already_has_special_tokens:
|
||||||
|
return super().get_special_tokens_mask(
|
||||||
|
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_ids_1 is not None:
|
||||||
|
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||||
|
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||||
|
|
||||||
|
def create_token_type_ids_from_sequences(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Create the token type IDs corresponding to the sequences passed. `What are token type IDs?
|
||||||
|
<../glossary.html#token-type-ids>`__
|
||||||
|
|
||||||
|
Should be overridden in a subclass if the model has a special way of building those.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`): The first tokenized sequence.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: The token type ids.
|
||||||
|
"""
|
||||||
|
sep = [self.sep_token_id]
|
||||||
|
cls = [self.cls_token_id]
|
||||||
|
question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")]
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return len(cls + token_ids_0 + sep) * [0]
|
||||||
|
|
||||||
|
if self.padding_side == "right":
|
||||||
|
# Input is question-then-context
|
||||||
|
return len(cls + token_ids_0 + question_suffix + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||||
|
else:
|
||||||
|
# Input is context-then-question
|
||||||
|
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + question_suffix + sep) * [1]
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
index = 0
|
||||||
|
if os.path.isdir(save_directory):
|
||||||
|
vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||||
|
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||||
|
if index != token_index:
|
||||||
|
logger.warning(
|
||||||
|
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
||||||
|
" Please check that the vocabulary is not corrupted!"
|
||||||
|
)
|
||||||
|
index = token_index
|
||||||
|
writer.write(token + "\n")
|
||||||
|
index += 1
|
||||||
|
return (vocab_file,)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTokenizer(object):
|
||||||
|
"""
|
||||||
|
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to lowercase the input when tokenizing.
|
||||||
|
never_split (:obj:`Iterable`, `optional`):
|
||||||
|
Collection of tokens which will never be split during tokenization. Only has an effect when
|
||||||
|
:obj:`do_basic_tokenize=True`
|
||||||
|
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to tokenize Chinese characters.
|
||||||
|
|
||||||
|
This should likely be deactivated for Japanese (see this `issue
|
||||||
|
<https://github.com/huggingface/transformers/issues/328>`__).
|
||||||
|
strip_accents: (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
||||||
|
value for :obj:`lowercase` (as in the original BERT).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
|
||||||
|
if never_split is None:
|
||||||
|
never_split = []
|
||||||
|
self.do_lower_case = do_lower_case
|
||||||
|
self.never_split = set(never_split)
|
||||||
|
self.tokenize_chinese_chars = tokenize_chinese_chars
|
||||||
|
self.strip_accents = strip_accents
|
||||||
|
|
||||||
|
def tokenize(self, text, never_split=None):
|
||||||
|
"""
|
||||||
|
        Basic tokenization of a piece of text: splits on whitespace only. For sub-word tokenization, see
        WordpieceTokenizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**never_split**: (`optional`) list of str
|
||||||
|
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
|
||||||
|
                :func:`PreTrainedTokenizer.tokenize`). List of tokens not to split.
|
||||||
|
"""
|
||||||
|
# union() returns a new set by concatenating the two sets.
|
||||||
|
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
|
||||||
|
text = self._clean_text(text)
|
||||||
|
|
||||||
|
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||||
|
# models. This is also applied to the English models now, but it doesn't
|
||||||
|
# matter since the English models were not trained on any Chinese data
|
||||||
|
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because the English Wikipedia does contain
        # some Chinese words).
|
||||||
|
if self.tokenize_chinese_chars:
|
||||||
|
text = self._tokenize_chinese_chars(text)
|
||||||
|
orig_tokens = whitespace_tokenize(text)
|
||||||
|
split_tokens = []
|
||||||
|
for token in orig_tokens:
|
||||||
|
if token not in never_split:
|
||||||
|
if self.do_lower_case:
|
||||||
|
token = token.lower()
|
||||||
|
if self.strip_accents is not False:
|
||||||
|
token = self._run_strip_accents(token)
|
||||||
|
elif self.strip_accents:
|
||||||
|
token = self._run_strip_accents(token)
|
||||||
|
split_tokens.extend(self._run_split_on_punc(token, never_split))
|
||||||
|
|
||||||
|
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
||||||
|
return output_tokens
|
||||||
|
|
||||||
|
def _run_strip_accents(self, text):
|
||||||
|
"""Strips accents from a piece of text."""
|
||||||
|
text = unicodedata.normalize("NFD", text)
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cat = unicodedata.category(char)
|
||||||
|
if cat == "Mn":
|
||||||
|
continue
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
def _run_split_on_punc(self, text, never_split=None):
|
||||||
|
"""Splits punctuation on a piece of text."""
|
||||||
|
if never_split is not None and text in never_split:
|
||||||
|
return [text]
|
||||||
|
chars = list(text)
|
||||||
|
i = 0
|
||||||
|
start_new_word = True
|
||||||
|
output = []
|
||||||
|
while i < len(chars):
|
||||||
|
char = chars[i]
|
||||||
|
if _is_punctuation(char):
|
||||||
|
output.append([char])
|
||||||
|
start_new_word = True
|
||||||
|
else:
|
||||||
|
if start_new_word:
|
||||||
|
output.append([])
|
||||||
|
start_new_word = False
|
||||||
|
output[-1].append(char)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return ["".join(x) for x in output]
|
||||||
|
|
||||||
|
def _tokenize_chinese_chars(self, text):
|
||||||
|
"""Adds whitespace around any CJK character."""
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cp = ord(char)
|
||||||
|
if self._is_chinese_char(cp):
|
||||||
|
output.append(" ")
|
||||||
|
output.append(char)
|
||||||
|
output.append(" ")
|
||||||
|
else:
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
def _is_chinese_char(self, cp):
|
||||||
|
"""Checks whether CP is the codepoint of a CJK character."""
|
||||||
|
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||||
|
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||||
|
#
|
||||||
|
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||||
|
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||||
|
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||||
|
# space-separated words, so they are not treated specially and handled
|
||||||
|
        # like all of the other languages.
|
||||||
|
if (
|
||||||
|
(cp >= 0x4E00 and cp <= 0x9FFF)
|
||||||
|
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
||||||
|
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
||||||
|
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
||||||
|
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
||||||
|
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
||||||
|
or (cp >= 0xF900 and cp <= 0xFAFF)
|
||||||
|
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
||||||
|
): #
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _clean_text(self, text):
|
||||||
|
"""Performs invalid character removal and whitespace cleanup on text."""
|
||||||
|
output = []
|
||||||
|
for char in text:
|
||||||
|
cp = ord(char)
|
||||||
|
if cp == 0 or cp == 0xFFFD or _is_control(char):
|
||||||
|
continue
|
||||||
|
if _is_whitespace(char):
|
||||||
|
output.append(" ")
|
||||||
|
else:
|
||||||
|
output.append(char)
|
||||||
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
|
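# Editor's note: the following is an illustrative sketch, not part of this file or this diff.
# It shows what the BasicTokenizer helpers above do end to end; the import path assumes the
# class is defined in this module (tokenization_splinter), as the diff suggests.
from transformers.models.splinter.tokenization_splinter import BasicTokenizer

basic = BasicTokenizer(do_lower_case=True)

# Lowercasing plus NFD accent stripping (the "Mn" category filter above), then
# punctuation is split into standalone tokens:
print(basic.tokenize("Héllo, world!"))  # ['hello', ',', 'world', '!']

# Every CJK codepoint is padded with spaces, so each character becomes its own token:
print(basic.tokenize("中文 text"))  # ['中', '文', 'text']
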
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
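A hedged usage sketch of the greedy longest-match-first loop above, with a toy vocabulary. The import path is an assumption (the class is defined in this module); any container supporting `in` works as the vocab.

from transformers.models.splinter.tokenization_splinter import WordpieceTokenizer

vocab = {"[UNK]", "un", "##aff", "##able"}
wordpiece = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

print(wordpiece.tokenize("unaffable"))     # ['un', '##aff', '##able']
print(wordpiece.tokenize("unaffordable"))  # ['[UNK]'] -- no full segmentation exists in this toy vocab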
src/transformers/models/splinter/tokenization_splinter_fast.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# coding=utf-8
# Copyright 2021 Tel Aviv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Tokenization classes for Splinter."""

import json
from typing import List, Optional, Tuple

from tokenizers import normalizers

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .tokenization_splinter import SplinterTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "tau/splinter-base": "https://huggingface.co/tau/splinter-base/resolve/main/vocab.txt",
        "tau/splinter-base-qass": "https://huggingface.co/tau/splinter-base-qass/resolve/main/vocab.txt",
        "tau/splinter-large": "https://huggingface.co/tau/splinter-large/resolve/main/vocab.txt",
        "tau/splinter-large-qass": "https://huggingface.co/tau/splinter-large-qass/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "tau/splinter-base": 512,
    "tau/splinter-base-qass": 512,
    "tau/splinter-large": 512,
    "tau/splinter-large-qass": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "tau/splinter-base": {"do_lower_case": False},
    "tau/splinter-base-qass": {"do_lower_case": False},
    "tau/splinter-large": {"do_lower_case": False},
    "tau/splinter-large-qass": {"do_lower_case": False},
}


class SplinterTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" Splinter tokenizer (backed by HuggingFace's `tokenizers` library). Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
    methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering. It is also used as the
            last token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        question_token (:obj:`str`, `optional`, defaults to :obj:`"[QUESTION]"`):
            The token used for constructing question representations.
        clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to clean the text before tokenization by removing any control characters and replacing
            all whitespaces with the classic one.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see `this
            issue <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
        wordpieces_prefix (:obj:`str`, `optional`, defaults to :obj:`"##"`):
            The prefix for subwords.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = SplinterTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        question_token="[QUESTION]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            additional_special_tokens=(question_token,),
            **kwargs,
        )

        pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        if (
            pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
            or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
        ):
            pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
            pre_tok_state["lowercase"] = do_lower_case
            pre_tok_state["strip_accents"] = strip_accents
            self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)

        self.do_lower_case = do_lower_case

    @property
    def question_token_id(self):
        """
        :obj:`Optional[int]`: Id of the question token in the vocabulary, used to condition the answer on a question
        representation.
        """
        return self.convert_tokens_to_ids(self.question_token)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a pair of sequences for question answering tasks by concatenating and adding special
        tokens. A Splinter sequence has the following format:

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences for question answering: ``[CLS] question_tokens [QUESTION] . [SEP] context_tokens [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                The question token IDs if pad_on_right, else context token IDs
            token_ids_1 (:obj:`List[int]`, `optional`):
                The context token IDs if pad_on_right, else question token IDs

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")]
        if self.padding_side == "right":
            # Input is question-then-context
            return cls + token_ids_0 + question_suffix + sep + token_ids_1 + sep
        else:
            # Input is context-then-question
            return cls + token_ids_0 + sep + token_ids_1 + question_suffix + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create the token type IDs corresponding to the sequences passed. `What are token type IDs?
        <../glossary.html#token-type-ids>`__

        Should be overridden in a subclass if the model has a special way of building those.

        Args:
            token_ids_0 (:obj:`List[int]`): The first tokenized sequence.
            token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.

        Returns:
            :obj:`List[int]`: The token type ids.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        question_suffix = [self.question_token_id] + [self.convert_tokens_to_ids(".")]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]

        if self.padding_side == "right":
            # Input is question-then-context
            return len(cls + token_ids_0 + question_suffix + sep) * [0] + len(token_ids_1 + sep) * [1]
        else:
            # Input is context-then-question
            return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + question_suffix + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)
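A hedged sketch of how the special-token layout above looks in practice. It assumes the released `tau/splinter-base-qass` tokenizer files register `[QUESTION]` as a special token and use a BERT-style `[CLS] ... [SEP]` post-processor, as the defaults above suggest.

from transformers import SplinterTokenizerFast

tokenizer = SplinterTokenizerFast.from_pretrained("tau/splinter-base-qass")

# Splinter marks the position where the answer should be predicted with the
# [QUESTION] special token:
text = "Brad was born in [QUESTION] . He returned to the United Kingdom later ."
input_ids = tokenizer.encode(text)  # [CLS] ... [QUESTION] ... [SEP]
print(tokenizer.convert_ids_to_tokens(input_ids))
print(tokenizer.convert_tokens_to_ids("[QUESTION]"))  # 104 in this vocabulary

For question/context pairs, `build_inputs_with_special_tokens(question_ids, context_ids)` above produces the ``[CLS] question [QUESTION] . [SEP] context [SEP]`` layout when `padding_side == "right"`.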
@@ -3121,6 +3121,41 @@ class Speech2TextPreTrainedModel:
        requires_backends(cls, ["torch"])


SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST = None


class SplinterForQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class SplinterLayer:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class SplinterModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class SplinterPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
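These dummy classes exist so that importing the Splinter model names without PyTorch installed raises an informative error instead of an opaque one at import time. A minimal sketch of the pattern, assuming `requires_backends` (the helper used above) is importable from `transformers.file_utils` in this version of the library; the class name below is hypothetical.

from transformers.file_utils import requires_backends


class HypotheticalSplinterDummy:  # illustrative stand-in, not a class from this PR
    def __init__(self, *args, **kwargs):
        # Raises an ImportError with install instructions if torch is not available;
        # otherwise this is a no-op.
        requires_backends(self, ["torch"])


try:
    HypotheticalSplinterDummy()
except ImportError as err:
    print(err)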
@@ -299,6 +299,15 @@ class RoFormerTokenizerFast:
        requires_backends(cls, ["tokenizers"])


class SplinterTokenizerFast:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tokenizers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tokenizers"])


class SqueezeBertTokenizerFast:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tokenizers"])
tests/test_modeling_splinter.py (new file, 219 lines)
@@ -0,0 +1,219 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Splinter model."""


import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask


if is_torch_available():
    import torch

    from transformers import SplinterConfig, SplinterForQuestionAnswering, SplinterModel
    from transformers.models.splinter.modeling_splinter import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST


class SplinterModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = SplinterConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = SplinterModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = SplinterForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class SplinterModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            SplinterModel,
            SplinterForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )

    def setUp(self):
        self.model_tester = SplinterModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SplinterConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = SplinterModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


@require_torch
class SplinterModelIntegrationTest(unittest.TestCase):
    @slow
    def test_splinter_question_answering(self):
        model = SplinterForQuestionAnswering.from_pretrained("tau/splinter-base-qass")

        # Input: "[CLS] Brad was born in [QUESTION] . He returned to the United Kingdom later . [SEP]"
        # Output should be the span "the United Kingdom"
        input_ids = torch.tensor(
            [[101, 7796, 1108, 1255, 1107, 104, 119, 1124, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
        )
        output = model(input_ids)

        expected_shape = torch.Size((1, 16))
        self.assertEqual(output.start_logits.shape, expected_shape)
        self.assertEqual(output.end_logits.shape, expected_shape)

        self.assertEqual(torch.argmax(output.start_logits), 10)
        self.assertEqual(torch.argmax(output.end_logits), 12)