mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add new model RoFormer (use rotary position embedding ) (#11684)
* add roformer * Update docs/source/model_doc/roformer.rst Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update docs/source/model_doc/roformer.rst Co-authored-by: Suraj Patil <surajp815@gmail.com> * update * add TFRoFormerSinusoidalPositionalEmbedding and fix TFMarianSinusoidalPositionalEmbedding * update docs * make style and make quality * roback * unchanged * rm copies from , this is a error in TFMarianSinusoidalPositionalEmbedding * update Copyright year * move # Add modeling imports here to the correct position * max_position_embeddings can be set to 1536 * # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoFormer * # Copied from transformers.models.bert.modeling_bert.BertLayer.__init__ with Bert->RoFormer * update tokenization_roformer * make style * add staticmethod apply_rotary_position_embeddings * add TF staticmethod apply_rotary_position_embeddings * update torch apply_rotary_position_embeddings * fix tf apply_rotary_position_embeddings error * make style * add pytorch RoFormerSelfAttentionRotaryPositionEmbeddingTest * add TF rotary_position_embeddings test * update test_modeling_rofomer * Update docs/source/model_doc/roformer.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/roformer/convert_roformer_original_tf_checkpoint_to_pytorch.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/roformer/modeling_roformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/roformer/modeling_roformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/roformer/modeling_tf_roformer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refact roformer tokenizer * add RoFormerTokenizerFast * add RoFormerTokenizationTest * add require_jieba * update Copyright * update tokenizer & add copy from * add option rotary_value * use rust jieba * use rjieba * use rust jieba * fix test_alignement_methods * slice normalized_string is too slow * add config.embedding_size when embedding_size!=hidden_size * fix pickle tokenizer * Update docs/source/model_doc/roformer.rst Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * make style and make quality Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
075fdab4fe
commit
206f06f2dd
@ -243,6 +243,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
||||
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
|
||||
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
|
||||
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
|
||||
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
|
@ -231,41 +231,44 @@ Supported models
|
||||
45. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
|
||||
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||
46. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
|
||||
46. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper a `RoFormer:
|
||||
Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and
|
||||
Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
|
||||
47. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
|
||||
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
|
||||
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
|
||||
47. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
48. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
|
||||
Krishna, and Kurt W. Keutzer.
|
||||
48. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
49. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
49. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
50. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
50. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
51. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
51. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
|
||||
52. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
|
||||
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
|
||||
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
|
||||
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
52. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
53. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
53. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
54. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
54. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
55. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
55. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
56. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
56. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
57. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
57. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
58. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
|
||||
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||
|
||||
@ -369,6 +372,8 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| RoFormer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Speech2Text | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
@ -520,6 +525,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
model_doc/reformer
|
||||
model_doc/retribert
|
||||
model_doc/roberta
|
||||
model_doc/roformer
|
||||
model_doc/speech_to_text
|
||||
model_doc/squeezebert
|
||||
model_doc/t5
|
||||
|
161
docs/source/model_doc/roformer.rst
Normal file
161
docs/source/model_doc/roformer.rst
Normal file
@ -0,0 +1,161 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
RoFormer
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The RoFormer model was proposed in `RoFormer: Enhanced Transformer with Rotary Position Embedding
|
||||
<https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Position encoding in transformer architecture provides supervision for dependency modeling between elements at
|
||||
different positions in the sequence. We investigate various methods to encode positional information in
|
||||
transformer-based language models and propose a novel implementation named Rotary Position Embedding(RoPE). The
|
||||
proposed RoPE encodes absolute positional information with rotation matrix and naturally incorporates explicit relative
|
||||
position dependency in self-attention formulation. Notably, RoPE comes with valuable properties such as flexibility of
|
||||
being expand to any sequence lengths, decaying inter-token dependency with increasing relative distances, and
|
||||
capability of equipping the linear self-attention with relative position encoding. As a result, the enhanced
|
||||
transformer with rotary position embedding, or RoFormer, achieves superior performance in tasks with long texts. We
|
||||
release the theoretical analysis along with some preliminary experiment results on Chinese data. The undergoing
|
||||
experiment for English benchmark will soon be updated.*
|
||||
|
||||
Tips:
|
||||
|
||||
- RoFormer is a BERT-like autoencoding model with rotary position embeddings. Rotary position embeddings have shown
|
||||
improved performance on classification tasks with long texts.
|
||||
|
||||
|
||||
This model was contributed by `junnyu <https://huggingface.co/junnyu>`__. The original code can be found `here
|
||||
<https://github.com/ZhuiyiTechnology/roformer>`__.
|
||||
|
||||
RoFormerConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerConfig
|
||||
:members:
|
||||
|
||||
|
||||
RoFormerTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerTokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
RobertaTokenizerFast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerTokenizerFast
|
||||
:members: build_inputs_with_special_tokens
|
||||
|
||||
|
||||
RoFormerModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerModel
|
||||
:members: forward
|
||||
|
||||
|
||||
RoFormerForCausalLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerForCausalLM
|
||||
:members: forward
|
||||
|
||||
|
||||
RoFormerForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerForMaskedLM
|
||||
:members: forward
|
||||
|
||||
|
||||
RoFormerForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
RoFormerForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerForMultipleChoice
|
||||
:members: forward
|
||||
|
||||
|
||||
RoFormerForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerForTokenClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
RoFormerForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.RoFormerForQuestionAnswering
|
||||
:members: forward
|
||||
|
||||
|
||||
TFRoFormerModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFRoFormerForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerForMaskedLM
|
||||
:members: call
|
||||
|
||||
|
||||
TFRoFormerForCausalLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerForCausalLM
|
||||
:members: call
|
||||
|
||||
|
||||
TFRoFormerForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerForSequenceClassification
|
||||
:members: call
|
||||
|
||||
|
||||
TFRoFormerForMultipleChoice
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerForMultipleChoice
|
||||
:members: call
|
||||
|
||||
|
||||
TFRoFormerForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerForTokenClassification
|
||||
:members: call
|
||||
|
||||
|
||||
TFRoFormerForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFRoFormerForQuestionAnswering
|
||||
:members: call
|
@ -218,6 +218,7 @@ _import_structure = {
|
||||
"models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
|
||||
"models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
|
||||
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
|
||||
"models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"],
|
||||
"models.speech_to_text": [
|
||||
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"Speech2TextConfig",
|
||||
@ -322,6 +323,7 @@ else:
|
||||
# tokenizers-backed objects
|
||||
if is_tokenizers_available():
|
||||
# Fast tokenizers
|
||||
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
|
||||
_import_structure["models.clip"].append("CLIPTokenizerFast")
|
||||
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
|
||||
_import_structure["models.albert"].append("AlbertTokenizerFast")
|
||||
@ -927,6 +929,21 @@ if is_torch_available():
|
||||
"RobertaModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.roformer"].extend(
|
||||
[
|
||||
"ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"RoFormerForCausalLM",
|
||||
"RoFormerForMaskedLM",
|
||||
"RoFormerForMultipleChoice",
|
||||
"RoFormerForQuestionAnswering",
|
||||
"RoFormerForSequenceClassification",
|
||||
"RoFormerForTokenClassification",
|
||||
"RoFormerLayer",
|
||||
"RoFormerModel",
|
||||
"RoFormerPreTrainedModel",
|
||||
"load_tf_weights_in_roformer",
|
||||
]
|
||||
)
|
||||
_import_structure["models.speech_to_text"].extend(
|
||||
[
|
||||
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -1339,6 +1356,20 @@ if is_tf_available():
|
||||
"TFRobertaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.roformer"].extend(
|
||||
[
|
||||
"TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFRoFormerForCausalLM",
|
||||
"TFRoFormerForMaskedLM",
|
||||
"TFRoFormerForMultipleChoice",
|
||||
"TFRoFormerForQuestionAnswering",
|
||||
"TFRoFormerForSequenceClassification",
|
||||
"TFRoFormerForTokenClassification",
|
||||
"TFRoFormerLayer",
|
||||
"TFRoFormerModel",
|
||||
"TFRoFormerPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.t5"].extend(
|
||||
[
|
||||
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -1641,6 +1672,7 @@ if TYPE_CHECKING:
|
||||
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
|
||||
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
|
||||
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
|
||||
from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer
|
||||
from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
|
||||
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
|
||||
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
@ -1767,6 +1799,7 @@ if TYPE_CHECKING:
|
||||
from .models.reformer import ReformerTokenizerFast
|
||||
from .models.retribert import RetriBertTokenizerFast
|
||||
from .models.roberta import RobertaTokenizerFast
|
||||
from .models.roformer import RoFormerTokenizerFast
|
||||
from .models.squeezebert import SqueezeBertTokenizerFast
|
||||
from .models.t5 import T5TokenizerFast
|
||||
from .models.xlm_roberta import XLMRobertaTokenizerFast
|
||||
@ -2232,6 +2265,19 @@ if TYPE_CHECKING:
|
||||
RobertaForTokenClassification,
|
||||
RobertaModel,
|
||||
)
|
||||
from .models.roformer import (
|
||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
RoFormerForCausalLM,
|
||||
RoFormerForMaskedLM,
|
||||
RoFormerForMultipleChoice,
|
||||
RoFormerForQuestionAnswering,
|
||||
RoFormerForSequenceClassification,
|
||||
RoFormerForTokenClassification,
|
||||
RoFormerLayer,
|
||||
RoFormerModel,
|
||||
RoFormerPreTrainedModel,
|
||||
load_tf_weights_in_roformer,
|
||||
)
|
||||
from .models.speech_to_text import (
|
||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Speech2TextForConditionalGeneration,
|
||||
@ -2575,6 +2621,18 @@ if TYPE_CHECKING:
|
||||
TFRobertaModel,
|
||||
TFRobertaPreTrainedModel,
|
||||
)
|
||||
from .models.roformer import (
|
||||
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFRoFormerForCausalLM,
|
||||
TFRoFormerForMaskedLM,
|
||||
TFRoFormerForMultipleChoice,
|
||||
TFRoFormerForQuestionAnswering,
|
||||
TFRoFormerForSequenceClassification,
|
||||
TFRoFormerForTokenClassification,
|
||||
TFRoFormerLayer,
|
||||
TFRoFormerModel,
|
||||
TFRoFormerPreTrainedModel,
|
||||
)
|
||||
from .models.t5 import (
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFT5EncoderModel,
|
||||
|
@ -25,6 +25,7 @@ from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers,
|
||||
from tokenizers.models import BPE, Unigram, WordPiece
|
||||
|
||||
from .file_utils import requires_backends
|
||||
from .models.roformer.tokenization_utils import JiebaPreTokenizer
|
||||
|
||||
|
||||
class SentencePieceExtractor:
|
||||
@ -296,6 +297,43 @@ class RobertaConverter(Converter):
|
||||
return tokenizer
|
||||
|
||||
|
||||
class RoFormerConverter(Converter):
|
||||
def converted(self) -> Tokenizer:
|
||||
vocab = self.original_tokenizer.vocab
|
||||
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
||||
|
||||
strip_accents = False
|
||||
do_lower_case = False
|
||||
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
||||
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
||||
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
||||
|
||||
tokenizer.normalizer = normalizers.BertNormalizer(
|
||||
clean_text=True,
|
||||
handle_chinese_chars=False,
|
||||
strip_accents=strip_accents,
|
||||
lowercase=do_lower_case,
|
||||
)
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
||||
|
||||
cls = str(self.original_tokenizer.cls_token)
|
||||
sep = str(self.original_tokenizer.sep_token)
|
||||
cls_token_id = self.original_tokenizer.cls_token_id
|
||||
sep_token_id = self.original_tokenizer.sep_token_id
|
||||
|
||||
tokenizer.post_processor = processors.TemplateProcessing(
|
||||
single=f"{cls}:0 $A:0 {sep}:0",
|
||||
pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
|
||||
special_tokens=[
|
||||
(cls, cls_token_id),
|
||||
(sep, sep_token_id),
|
||||
],
|
||||
)
|
||||
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class DebertaConverter(Converter):
|
||||
def converted(self) -> Tokenizer:
|
||||
ot = self.original_tokenizer
|
||||
@ -755,6 +793,7 @@ SLOW_TO_FAST_CONVERTERS = {
|
||||
"ReformerTokenizer": ReformerConverter,
|
||||
"RetriBertTokenizer": BertConverter,
|
||||
"RobertaTokenizer": RobertaConverter,
|
||||
"RoFormerTokenizer": RoFormerConverter,
|
||||
"SqueezeBertTokenizer": BertConverter,
|
||||
"T5Tokenizer": T5Converter,
|
||||
"XLMRobertaTokenizer": XLMRobertaConverter,
|
||||
|
@ -68,6 +68,7 @@ from . import (
|
||||
reformer,
|
||||
retribert,
|
||||
roberta,
|
||||
roformer,
|
||||
speech_to_text,
|
||||
squeezebert,
|
||||
t5,
|
||||
|
@ -68,6 +68,7 @@ from ..rag.configuration_rag import RagConfig
|
||||
from ..reformer.configuration_reformer import ReformerConfig
|
||||
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
|
||||
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
from ..roformer.configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
|
||||
from ..speech_to_text.configuration_speech_to_text import (
|
||||
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
Speech2TextConfig,
|
||||
@ -91,6 +92,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
(key, value)
|
||||
for pretrained_map in [
|
||||
# Add archive maps here
|
||||
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@ -146,6 +148,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
CONFIG_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add configs here
|
||||
("roformer", RoFormerConfig),
|
||||
("clip", CLIPConfig),
|
||||
("bigbird_pegasus", BigBirdPegasusConfig),
|
||||
("deit", DeiTConfig),
|
||||
@ -207,6 +210,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
MODEL_NAMES_MAPPING = OrderedDict(
|
||||
[
|
||||
# Add full (and cased) model names here
|
||||
("roformer", "RoFormer"),
|
||||
("clip", "CLIP"),
|
||||
("bigbird_pegasus", "BigBirdPegasus"),
|
||||
("deit", "DeiT"),
|
||||
|
@ -240,6 +240,15 @@ from ..roberta.modeling_roberta import (
|
||||
RobertaForTokenClassification,
|
||||
RobertaModel,
|
||||
)
|
||||
from ..roformer.modeling_roformer import (
|
||||
RoFormerForCausalLM,
|
||||
RoFormerForMaskedLM,
|
||||
RoFormerForMultipleChoice,
|
||||
RoFormerForQuestionAnswering,
|
||||
RoFormerForSequenceClassification,
|
||||
RoFormerForTokenClassification,
|
||||
RoFormerModel,
|
||||
)
|
||||
from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel
|
||||
from ..squeezebert.modeling_squeezebert import (
|
||||
SqueezeBertForMaskedLM,
|
||||
@ -334,6 +343,7 @@ from .configuration_auto import (
|
||||
ReformerConfig,
|
||||
RetriBertConfig,
|
||||
RobertaConfig,
|
||||
RoFormerConfig,
|
||||
Speech2TextConfig,
|
||||
SqueezeBertConfig,
|
||||
T5Config,
|
||||
@ -354,6 +364,7 @@ logger = logging.get_logger(__name__)
|
||||
MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(RoFormerConfig, RoFormerModel),
|
||||
(CLIPConfig, CLIPModel),
|
||||
(BigBirdPegasusConfig, BigBirdPegasusModel),
|
||||
(DeiTConfig, DeiTModel),
|
||||
@ -451,6 +462,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model with LM heads mapping
|
||||
(RoFormerConfig, RoFormerForMaskedLM),
|
||||
(BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration),
|
||||
(GPTNeoConfig, GPTNeoForCausalLM),
|
||||
(BigBirdConfig, BigBirdForMaskedLM),
|
||||
@ -498,6 +510,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
(RoFormerConfig, RoFormerForCausalLM),
|
||||
(BigBirdPegasusConfig, BigBirdPegasusForCausalLM),
|
||||
(GPTNeoConfig, GPTNeoForCausalLM),
|
||||
(BigBirdConfig, BigBirdForCausalLM),
|
||||
@ -539,6 +552,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
(RoFormerConfig, RoFormerForMaskedLM),
|
||||
(BigBirdConfig, BigBirdForMaskedLM),
|
||||
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
|
||||
(ConvBertConfig, ConvBertForMaskedLM),
|
||||
@ -592,6 +606,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
(RoFormerConfig, RoFormerForSequenceClassification),
|
||||
(BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification),
|
||||
(BigBirdConfig, BigBirdForSequenceClassification),
|
||||
(ConvBertConfig, ConvBertForSequenceClassification),
|
||||
@ -630,6 +645,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Question Answering mapping
|
||||
(RoFormerConfig, RoFormerForQuestionAnswering),
|
||||
(BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering),
|
||||
(BigBirdConfig, BigBirdForQuestionAnswering),
|
||||
(ConvBertConfig, ConvBertForQuestionAnswering),
|
||||
@ -670,6 +686,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Token Classification mapping
|
||||
(RoFormerConfig, RoFormerForTokenClassification),
|
||||
(BigBirdConfig, BigBirdForTokenClassification),
|
||||
(ConvBertConfig, ConvBertForTokenClassification),
|
||||
(LayoutLMConfig, LayoutLMForTokenClassification),
|
||||
@ -699,6 +716,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Multiple Choice mapping
|
||||
(RoFormerConfig, RoFormerForMultipleChoice),
|
||||
(BigBirdConfig, BigBirdForMultipleChoice),
|
||||
(ConvBertConfig, ConvBertForMultipleChoice),
|
||||
(CamembertConfig, CamembertForMultipleChoice),
|
||||
|
@ -148,6 +148,15 @@ from ..roberta.modeling_tf_roberta import (
|
||||
TFRobertaForTokenClassification,
|
||||
TFRobertaModel,
|
||||
)
|
||||
from ..roformer.modeling_tf_roformer import (
|
||||
TFRoFormerForCausalLM,
|
||||
TFRoFormerForMaskedLM,
|
||||
TFRoFormerForMultipleChoice,
|
||||
TFRoFormerForQuestionAnswering,
|
||||
TFRoFormerForSequenceClassification,
|
||||
TFRoFormerForTokenClassification,
|
||||
TFRoFormerModel,
|
||||
)
|
||||
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
|
||||
from ..transfo_xl.modeling_tf_transfo_xl import (
|
||||
TFTransfoXLForSequenceClassification,
|
||||
@ -206,6 +215,7 @@ from .configuration_auto import (
|
||||
OpenAIGPTConfig,
|
||||
PegasusConfig,
|
||||
RobertaConfig,
|
||||
RoFormerConfig,
|
||||
T5Config,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
@ -220,6 +230,7 @@ logger = logging.get_logger(__name__)
|
||||
TF_MODEL_MAPPING = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
(RoFormerConfig, TFRoFormerModel),
|
||||
(ConvBertConfig, TFConvBertModel),
|
||||
(LEDConfig, TFLEDModel),
|
||||
(LxmertConfig, TFLxmertModel),
|
||||
@ -285,6 +296,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model with LM heads mapping
|
||||
(RoFormerConfig, TFRoFormerForMaskedLM),
|
||||
(ConvBertConfig, TFConvBertForMaskedLM),
|
||||
(LEDConfig, TFLEDForConditionalGeneration),
|
||||
(T5Config, TFT5ForConditionalGeneration),
|
||||
@ -315,6 +327,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Causal LM mapping
|
||||
(RoFormerConfig, TFRoFormerForCausalLM),
|
||||
(BertConfig, TFBertLMHeadModel),
|
||||
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
||||
(GPT2Config, TFGPT2LMHeadModel),
|
||||
@ -331,6 +344,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
(RoFormerConfig, TFRoFormerForMaskedLM),
|
||||
(ConvBertConfig, TFConvBertForMaskedLM),
|
||||
(DistilBertConfig, TFDistilBertForMaskedLM),
|
||||
(AlbertConfig, TFAlbertForMaskedLM),
|
||||
@ -368,6 +382,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
(RoFormerConfig, TFRoFormerForSequenceClassification),
|
||||
(ConvBertConfig, TFConvBertForSequenceClassification),
|
||||
(DistilBertConfig, TFDistilBertForSequenceClassification),
|
||||
(AlbertConfig, TFAlbertForSequenceClassification),
|
||||
@ -394,6 +409,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Question Answering mapping
|
||||
(RoFormerConfig, TFRoFormerForQuestionAnswering),
|
||||
(ConvBertConfig, TFConvBertForQuestionAnswering),
|
||||
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
||||
(AlbertConfig, TFAlbertForQuestionAnswering),
|
||||
@ -415,6 +431,7 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Token Classification mapping
|
||||
(RoFormerConfig, TFRoFormerForTokenClassification),
|
||||
(ConvBertConfig, TFConvBertForTokenClassification),
|
||||
(DistilBertConfig, TFDistilBertForTokenClassification),
|
||||
(AlbertConfig, TFAlbertForTokenClassification),
|
||||
@ -437,6 +454,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Multiple Choice mapping
|
||||
(RoFormerConfig, TFRoFormerForMultipleChoice),
|
||||
(ConvBertConfig, TFConvBertForMultipleChoice),
|
||||
(CamembertConfig, TFCamembertForMultipleChoice),
|
||||
(XLMConfig, TFXLMForMultipleChoice),
|
||||
|
@ -51,6 +51,7 @@ from ..prophetnet.tokenization_prophetnet import ProphetNetTokenizer
|
||||
from ..rag.tokenization_rag import RagTokenizer
|
||||
from ..retribert.tokenization_retribert import RetriBertTokenizer
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
from ..roformer.tokenization_roformer import RoFormerTokenizer
|
||||
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
|
||||
from ..tapas.tokenization_tapas import TapasTokenizer
|
||||
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
|
||||
@ -98,6 +99,7 @@ from .configuration_auto import (
|
||||
ReformerConfig,
|
||||
RetriBertConfig,
|
||||
RobertaConfig,
|
||||
RoFormerConfig,
|
||||
Speech2TextConfig,
|
||||
SqueezeBertConfig,
|
||||
T5Config,
|
||||
@ -228,6 +230,7 @@ logger = logging.get_logger(__name__)
|
||||
TOKENIZER_MAPPING = OrderedDict(
|
||||
[
|
||||
(RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
|
||||
(RoFormerConfig, (RoFormerTokenizer, None)),
|
||||
(T5Config, (T5Tokenizer, T5TokenizerFast)),
|
||||
(MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
|
||||
(MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
|
||||
|
115
src/transformers/models/roformer/__init__.py
Normal file
115
src/transformers/models/roformer/__init__.py
Normal file
@ -0,0 +1,115 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig"],
|
||||
"tokenization_roformer": ["RoFormerTokenizer"],
|
||||
}
|
||||
|
||||
if is_tokenizers_available():
|
||||
_import_structure["tokenization_roformer_fast"] = ["RoFormerTokenizerFast"]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_roformer"] = [
|
||||
"ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"RoFormerForCausalLM",
|
||||
"RoFormerForMaskedLM",
|
||||
"RoFormerForMultipleChoice",
|
||||
"RoFormerForQuestionAnswering",
|
||||
"RoFormerForSequenceClassification",
|
||||
"RoFormerForTokenClassification",
|
||||
"RoFormerLayer",
|
||||
"RoFormerModel",
|
||||
"RoFormerPreTrainedModel",
|
||||
"load_tf_weights_in_roformer",
|
||||
]
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_roformer"] = [
|
||||
"TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFRoFormerForCausalLM",
|
||||
"TFRoFormerForMaskedLM",
|
||||
"TFRoFormerForMultipleChoice",
|
||||
"TFRoFormerForQuestionAnswering",
|
||||
"TFRoFormerForSequenceClassification",
|
||||
"TFRoFormerForTokenClassification",
|
||||
"TFRoFormerLayer",
|
||||
"TFRoFormerModel",
|
||||
"TFRoFormerPreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
|
||||
from .tokenization_roformer import RoFormerTokenizer
|
||||
|
||||
if is_tokenizers_available():
|
||||
from .tokenization_roformer_fast import RoFormerTokenizerFast
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_roformer import (
|
||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
RoFormerForCausalLM,
|
||||
RoFormerForMaskedLM,
|
||||
RoFormerForMultipleChoice,
|
||||
RoFormerForQuestionAnswering,
|
||||
RoFormerForSequenceClassification,
|
||||
RoFormerForTokenClassification,
|
||||
RoFormerLayer,
|
||||
RoFormerModel,
|
||||
RoFormerPreTrainedModel,
|
||||
load_tf_weights_in_roformer,
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_roformer import (
|
||||
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFRoFormerForCausalLM,
|
||||
TFRoFormerForMaskedLM,
|
||||
TFRoFormerForMultipleChoice,
|
||||
TFRoFormerForQuestionAnswering,
|
||||
TFRoFormerForSequenceClassification,
|
||||
TFRoFormerForTokenClassification,
|
||||
TFRoFormerLayer,
|
||||
TFRoFormerModel,
|
||||
TFRoFormerPreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
class _LazyModule(_BaseLazyModule):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
__file__ = globals()["__file__"]
|
||||
__path__ = [os.path.dirname(__file__)]
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
|
134
src/transformers/models/roformer/configuration_roformer.py
Normal file
134
src/transformers/models/roformer/configuration_roformer.py
Normal file
@ -0,0 +1,134 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" RoFormer model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
|
||||
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json"
|
||||
# See all RoFormer models at https://huggingface.co/models?filter=roformer
|
||||
}
|
||||
|
||||
|
||||
class RoFormerConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.RoFormerModel`. It is used to
|
||||
instantiate an RoFormer model according to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to that of the RoFormer
|
||||
`junnyu/roformer_chinese_base <https://huggingface.co/junnyu/roformer_chinese_base>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 50000):
|
||||
Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
|
||||
the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
|
||||
:class:`~transformers.TFRoFormerModel`.
|
||||
embedding_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimension of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 1536):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 1536).
|
||||
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
|
||||
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.RoFormerModel`
|
||||
or :class:`~transformers.TFRoFormerModel`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if ``config.is_decoder=True``.
|
||||
rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not apply rotary position embeddings on value layer.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import RoFormerModel, RoFormerConfig
|
||||
|
||||
>>> # Initializing a RoFormer junnyu/roformer_chinese_base style configuration
|
||||
>>> configuration = RoFormerConfig()
|
||||
|
||||
>>> # Initializing a model from the junnyu/roformer_chinese_base style configuration
|
||||
>>> model = RoFormerModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "roformer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50000,
|
||||
embedding_size=768,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=1536,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
gradient_checkpointing=False,
|
||||
rotary_value=False,
|
||||
use_cache=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.rotary_value = rotary_value
|
||||
self.use_cache = use_cache
|
@ -0,0 +1,61 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert RoFormer checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import RoFormerConfig, RoFormerForMaskedLM, load_tf_weights_in_roformer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = RoFormerConfig.from_json_file(bert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = RoFormerForMaskedLM(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_roformer(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path, _use_new_zipfile_serialization=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
1575
src/transformers/models/roformer/modeling_roformer.py
Normal file
1575
src/transformers/models/roformer/modeling_roformer.py
Normal file
File diff suppressed because it is too large
Load Diff
1523
src/transformers/models/roformer/modeling_tf_roformer.py
Normal file
1523
src/transformers/models/roformer/modeling_tf_roformer.py
Normal file
File diff suppressed because it is too large
Load Diff
317
src/transformers/models/roformer/tokenization_roformer.py
Normal file
317
src/transformers/models/roformer/tokenization_roformer.py
Normal file
@ -0,0 +1,317 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for RoFormer."""
|
||||
|
||||
import collections
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer, load_vocab
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
|
||||
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
|
||||
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"junnyu/roformer_chinese_small": {"do_lower_case": True},
|
||||
"junnyu/roformer_chinese_base": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
class RoFormerTokenizer(PreTrainedTokenizer):
|
||||
r"""
|
||||
Construct a RoFormer tokenizer. Based on `Rust Jieba <https://pypi.org/project/rjieba/>`.
|
||||
|
||||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
|
||||
Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`str`):
|
||||
File containing the vocabulary.
|
||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to lowercase the input when tokenizing.
|
||||
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to do basic tokenization before WordPiece.
|
||||
never_split (:obj:`Iterable`, `optional`):
|
||||
Collection of tokens which will never be split during tokenization. Only has an effect when
|
||||
:obj:`do_basic_tokenize=True`
|
||||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
|
||||
The classifier token which is used when doing sequence classification (classification of the whole sequence
|
||||
instead of per-token classification). It is the first token of the sequence when built with special tokens.
|
||||
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
|
||||
The token used for masking values. This is the token used when training this model with masked language
|
||||
modeling. This is the token which the model will try to predict.
|
||||
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to tokenize Chinese characters.
|
||||
|
||||
This should likely be deactivated for Japanese (see this `issue
|
||||
<https://github.com/huggingface/transformers/issues/328>`__).
|
||||
strip_accents: (:obj:`bool`, `optional`):
|
||||
Whether or not to strip all accents. If this option is not specified, then it will be determined by the
|
||||
value for :obj:`lowercase` (as in the original BERT).
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import RoFormerTokenizer
|
||||
>>> tokenizer = RoFormerTokenizer.from_pretrained('junnyu/roformer_chinese_base')
|
||||
>>> tokenizer.tokenize("今天天气非常好。")
|
||||
# ['今', '天', '天', '气', '非常', '好', '。']
|
||||
|
||||
"""
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=True,
|
||||
do_basic_tokenize=True,
|
||||
never_split=None,
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]",
|
||||
tokenize_chinese_chars=True,
|
||||
strip_accents=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
do_lower_case=do_lower_case,
|
||||
do_basic_tokenize=do_basic_tokenize,
|
||||
never_split=never_split,
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
mask_token=mask_token,
|
||||
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||
strip_accents=strip_accents,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
|
||||
"model use `tokenizer = RoFormerTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
||||
self.do_basic_tokenize = do_basic_tokenize
|
||||
if do_basic_tokenize:
|
||||
self.basic_tokenizer = BasicTokenizer(
|
||||
do_lower_case=do_lower_case,
|
||||
never_split=never_split,
|
||||
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||
strip_accents=strip_accents,
|
||||
)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
||||
try:
|
||||
import rjieba
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You need to install rjieba to use RoFormerTokenizer."
|
||||
"See https://pypi.org/project/rjieba/ for installation."
|
||||
)
|
||||
self.jieba = rjieba
|
||||
|
||||
@property
|
||||
def do_lower_case(self):
|
||||
return self.basic_tokenizer.do_lower_case
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["jieba"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
try:
|
||||
import rjieba
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You need to install rjieba to use RoFormerTokenizer."
|
||||
"See https://pypi.org/project/rjieba/ for installation."
|
||||
)
|
||||
self.jieba = rjieba
|
||||
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text, use_jieba=True):
|
||||
split_tokens = []
|
||||
if use_jieba:
|
||||
for wholword in self.jieba.cut(text, False):
|
||||
if wholword in self.vocab:
|
||||
split_tokens.append(wholword)
|
||||
else:
|
||||
# use bert tokenizer to _tokenize
|
||||
char_list = self._tokenize(wholword, use_jieba=False)
|
||||
split_tokens.extend(char_list)
|
||||
else:
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
else:
|
||||
split_tokens += self.wordpiece_tokenizer.tokenize(token)
|
||||
else:
|
||||
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
||||
return split_tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.ids_to_tokens.get(index, self.unk_token)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
out_string = " ".join(tokens).replace(" ##", "").strip()
|
||||
return out_string
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A RoFormer sequence has the following format:
|
||||
|
||||
- single sequence: ``[CLS] X [SEP]``
|
||||
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
sep = [self.sep_token_id]
|
||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` method.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||
)
|
||||
|
||||
if token_ids_1 is not None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer
|
||||
sequence pair mask has the following format:
|
||||
|
||||
::
|
||||
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
|
||||
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
||||
sequence(s).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
index = 0
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
logger.warning(
|
||||
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
|
||||
" Please check that the vocabulary is not corrupted!"
|
||||
)
|
||||
index = token_index
|
||||
writer.write(token + "\n")
|
||||
index += 1
|
||||
return (vocab_file,)
|
191
src/transformers/models/roformer/tokenization_roformer_fast.py
Normal file
191
src/transformers/models/roformer/tokenization_roformer_fast.py
Normal file
@ -0,0 +1,191 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for RoFormer."""
|
||||
import json
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from tokenizers import normalizers
|
||||
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
|
||||
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
from .tokenization_roformer import RoFormerTokenizer
|
||||
from .tokenization_utils import JiebaPreTokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
|
||||
"junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
|
||||
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"junnyu/roformer_chinese_small": {"do_lower_case": True},
|
||||
"junnyu/roformer_chinese_base": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
class RoFormerTokenizerFast(PreTrainedTokenizerFast):
|
||||
r"""
|
||||
Construct a "fast" RoFormer tokenizer (backed by HuggingFace's `tokenizers` library).
|
||||
|
||||
:class:`~transformers.RoFormerTokenizerFast` is almost identical to :class:`~transformers.BertTokenizerFast` and
|
||||
runs end-to-end tokenization: punctuation splitting and wordpiece. There are some difference between them when
|
||||
tokenizing Chinese.
|
||||
|
||||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
|
||||
methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import RoFormerTokenizerFast
|
||||
>>> tokenizer = RoFormerTokenizerFast.from_pretrained('junnyu/roformer_chinese_base')
|
||||
>>> tokenizer.tokenize("今天天气非常好。")
|
||||
# ['今', '天', '天', '气', '非常', '好', '。']
|
||||
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
slow_tokenizer_class = RoFormerTokenizer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
tokenizer_file=None,
|
||||
do_lower_case=True,
|
||||
unk_token="[UNK]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
mask_token="[MASK]",
|
||||
tokenize_chinese_chars=True,
|
||||
strip_accents=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
do_lower_case=do_lower_case,
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
mask_token=mask_token,
|
||||
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||
strip_accents=strip_accents,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||
if (
|
||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||
):
|
||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||
pre_tok_state["lowercase"] = do_lower_case
|
||||
pre_tok_state["strip_accents"] = strip_accents
|
||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["_tokenizer"].pre_tokenizer = BertPreTokenizer()
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
vocab = self.__dict__["_tokenizer"].get_vocab()
|
||||
self.__dict__["_tokenizer"].pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A RoFormer sequence has the following format:
|
||||
|
||||
- single sequence: ``[CLS] X [SEP]``
|
||||
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||
|
||||
if token_ids_1:
|
||||
output += token_ids_1 + [self.sep_token_id]
|
||||
|
||||
return output
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A RoFormer
|
||||
sequence pair mask has the following format:
|
||||
|
||||
::
|
||||
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
|
||||
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
||||
sequence(s).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||
return tuple(files)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory,
|
||||
legacy_format=None,
|
||||
filename_prefix=None,
|
||||
push_to_hub=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.backend_tokenizer.pre_tokenizer = BertPreTokenizer()
|
||||
return super().save_pretrained(save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs)
|
68
src/transformers/models/roformer/tokenization_utils.py
Normal file
68
src/transformers/models/roformer/tokenization_utils.py
Normal file
@ -0,0 +1,68 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization utils for RoFormer."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from tokenizers import NormalizedString, PreTokenizedString, normalizers
|
||||
|
||||
|
||||
class JiebaPreTokenizer:
|
||||
def __init__(self, vocab) -> None:
|
||||
self.vocab = vocab
|
||||
self.normalizers = normalizers.BertNormalizer(
|
||||
clean_text=False,
|
||||
handle_chinese_chars=True,
|
||||
strip_accents=False,
|
||||
lowercase=False,
|
||||
)
|
||||
try:
|
||||
import rjieba
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You need to install rjieba to use RoFormerTokenizer."
|
||||
"See https://pypi.org/project/rjieba/ for installation."
|
||||
)
|
||||
self.jieba = rjieba
|
||||
|
||||
def jieba_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
|
||||
splits = []
|
||||
|
||||
# this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
|
||||
# for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
|
||||
# if token in self.vocab:
|
||||
# splits.append(normalized_string.slice((start, end)))
|
||||
# else:
|
||||
# token_list = self.normalizers.normalize_str(token).split()
|
||||
# for token in token_list:
|
||||
# if token:
|
||||
# end = start + len(token)
|
||||
# splits.append(normalized_string.slice((start, end)))
|
||||
# start = end
|
||||
|
||||
# this code test_alignement_methods can't pass but fast (300ms)
|
||||
for token in self.jieba.cut(str(normalized_string), False):
|
||||
if token in self.vocab:
|
||||
splits.append(NormalizedString(token))
|
||||
else:
|
||||
token_list = self.normalizers.normalize_str(token).split()
|
||||
for token in token_list:
|
||||
if token:
|
||||
splits.append(NormalizedString(token))
|
||||
|
||||
return splits
|
||||
|
||||
def pre_tokenize(self, pretok: PreTokenizedString):
|
||||
pretok.split(self.jieba_split)
|
@ -2553,6 +2553,86 @@ class RobertaModel:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class RoFormerForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerLayer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class RoFormerPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
def load_tf_weights_in_roformer(*args, **kwargs):
|
||||
requires_backends(load_tf_weights_in_roformer, ["torch"])
|
||||
|
||||
|
||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -1479,6 +1479,82 @@ class TFRobertaPreTrainedModel:
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFRoFormerForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerLayer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFRoFormerPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -281,6 +281,15 @@ class RobertaTokenizerFast:
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
||||
class RoFormerTokenizerFast:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
||||
|
||||
class SqueezeBertTokenizerFast:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tokenizers"])
|
||||
|
@ -6,6 +6,7 @@ from collections import OrderedDict
|
||||
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("RoFormerConfig", "RoFormerForQuestionAnswering"),
|
||||
("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"),
|
||||
("BigBirdConfig", "BigBirdForQuestionAnswering"),
|
||||
("ConvBertConfig", "ConvBertForQuestionAnswering"),
|
||||
|
556
tests/test_modeling_roformer.py
Normal file
556
tests/test_modeling_roformer.py
Normal file
@ -0,0 +1,556 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch RoFormer model. """
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
RoFormerConfig,
|
||||
RoFormerForCausalLM,
|
||||
RoFormerForMaskedLM,
|
||||
RoFormerForMultipleChoice,
|
||||
RoFormerForQuestionAnswering,
|
||||
RoFormerForSequenceClassification,
|
||||
RoFormerForTokenClassification,
|
||||
RoFormerModel,
|
||||
)
|
||||
from transformers.models.roformer.modeling_roformer import (
|
||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
RoFormerSelfAttention,
|
||||
RoFormerSinusoidalPositionalEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class RoFormerModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = RoFormerConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = RoFormerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_model_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = RoFormerModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = RoFormerForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = RoFormerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
model = RoFormerForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(
|
||||
next_input_ids,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = RoFormerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = RoFormerForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = RoFormerForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = RoFormerForMultipleChoice(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
result = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class RoFormerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
RoFormerModel,
|
||||
RoFormerForMaskedLM,
|
||||
RoFormerForCausalLM,
|
||||
RoFormerForMultipleChoice,
|
||||
RoFormerForQuestionAnswering,
|
||||
RoFormerForSequenceClassification,
|
||||
RoFormerForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (RoFormerForCausalLM,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RoFormerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=RoFormerConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_model_as_decoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||
|
||||
def test_model_as_decoder_with_default_input_mask(self):
|
||||
# This regression test was failing with PyTorch < 1.3
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
|
||||
input_mask = None
|
||||
|
||||
self.model_tester.create_and_check_model_as_decoder(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = RoFormerModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class RoFormerModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
# TODO Replace vocab size
|
||||
vocab_size = 50000
|
||||
|
||||
expected_shape = torch.Size((1, 6, vocab_size))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
# TODO Replace values below with what was printed above.
|
||||
expected_slice = torch.tensor(
|
||||
[[[-0.1205, -1.0265, 0.2922], [-1.5134, 0.1974, 0.1519], [-5.0135, -3.9003, -0.8404]]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
@require_torch
|
||||
class RoFormerSinusoidalPositionalEmbeddingTest(unittest.TestCase):
|
||||
tolerance = 1e-4
|
||||
|
||||
def test_basic(self):
|
||||
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
|
||||
emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6).to(torch_device)
|
||||
emb = emb1(input_ids.shape)
|
||||
desired_weights = torch.tensor(
|
||||
[[0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000], [0.8415, 0.0464, 0.0022, 0.5403, 0.9989, 1.0000]]
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
torch.allclose(emb, desired_weights, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_weights}\ngot:\n{emb[0]}\n",
|
||||
)
|
||||
|
||||
def test_positional_emb_weights_against_roformer(self):
|
||||
|
||||
desired_weights = torch.tensor(
|
||||
[
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.8415, 0.8219, 0.8020, 0.7819, 0.7617],
|
||||
[0.9093, 0.9364, 0.9581, 0.9749, 0.9870],
|
||||
]
|
||||
).to(torch_device)
|
||||
emb1 = RoFormerSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512).to(torch_device)
|
||||
weights = emb1.weight.data[:3, :5].to(torch_device)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(weights, desired_weights, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_weights}\ngot:\n{weights}\n",
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class RoFormerSelfAttentionRotaryPositionEmbeddingTest(unittest.TestCase):
|
||||
tolerance = 1e-4
|
||||
|
||||
def test_apply_rotary_position_embeddings(self):
|
||||
# 2,12,16,64
|
||||
query_layer = (
|
||||
torch.arange(2 * 12 * 16 * 64, dtype=torch.float, device=torch_device).reshape(2, 12, 16, 64) / 100
|
||||
).to(torch_device)
|
||||
key_layer = (
|
||||
-torch.arange(2 * 12 * 16 * 64, dtype=torch.float, device=torch_device).reshape(2, 12, 16, 64) / 100
|
||||
).to(torch_device)
|
||||
embed_positions = RoFormerSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=64).to(torch_device)
|
||||
sinusoidal_pos = embed_positions([2, 16, 768])[None, None, :, :]
|
||||
|
||||
query_layer, key_layer = RoFormerSelfAttention.apply_rotary_position_embeddings(
|
||||
sinusoidal_pos, query_layer, key_layer
|
||||
)
|
||||
|
||||
desired_query_layer = torch.tensor(
|
||||
[
|
||||
[0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700],
|
||||
[-0.2012, 0.8897, 0.0263, 0.9401, 0.2074, 0.9463, 0.3481, 0.9343],
|
||||
[-1.7057, 0.6271, -1.2145, 1.3897, -0.6303, 1.7647, -0.1173, 1.8985],
|
||||
[-2.1731, -1.6397, -2.7358, 0.2854, -2.1840, 1.7183, -1.3018, 2.4871],
|
||||
[0.2717, -3.6173, -2.9206, -2.1988, -3.6638, 0.3858, -2.9155, 2.2980],
|
||||
[3.9859, -2.1580, -0.7984, -4.4904, -4.1181, -2.0252, -4.4782, 1.1253],
|
||||
]
|
||||
).to(torch_device)
|
||||
desired_key_layer = torch.tensor(
|
||||
[
|
||||
[0.0000, -0.0100, -0.0200, -0.0300, -0.0400, -0.0500, -0.0600, -0.0700],
|
||||
[0.2012, -0.8897, -0.0263, -0.9401, -0.2074, -0.9463, -0.3481, -0.9343],
|
||||
[1.7057, -0.6271, 1.2145, -1.3897, 0.6303, -1.7647, 0.1173, -1.8985],
|
||||
[2.1731, 1.6397, 2.7358, -0.2854, 2.1840, -1.7183, 1.3018, -2.4871],
|
||||
[-0.2717, 3.6173, 2.9206, 2.1988, 3.6638, -0.3858, 2.9155, -2.2980],
|
||||
[-3.9859, 2.1580, 0.7984, 4.4904, 4.1181, 2.0252, 4.4782, -1.1253],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(query_layer[0, 0, :6, :8], desired_query_layer, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_query_layer}\ngot:\n{query_layer}\n",
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(key_layer[0, 0, :6, :8], desired_key_layer, atol=self.tolerance),
|
||||
msg=f"\nexp:\n{desired_key_layer}\ngot:\n{key_layer}\n",
|
||||
)
|
401
tests/test_modeling_tf_roformer.py
Normal file
401
tests/test_modeling_tf_roformer.py
Normal file
@ -0,0 +1,401 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import RoFormerConfig, is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
TFRoFormerForCausalLM,
|
||||
TFRoFormerForMaskedLM,
|
||||
TFRoFormerForMultipleChoice,
|
||||
TFRoFormerForQuestionAnswering,
|
||||
TFRoFormerForSequenceClassification,
|
||||
TFRoFormerForTokenClassification,
|
||||
TFRoFormerModel,
|
||||
)
|
||||
from transformers.models.roformer.modeling_tf_roformer import (
|
||||
TFRoFormerSelfAttention,
|
||||
TFRoFormerSinusoidalPositionalEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class TFRoFormerModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
self.seq_length = 7
|
||||
self.is_training = True
|
||||
self.use_input_mask = True
|
||||
self.use_token_type_ids = True
|
||||
self.use_labels = True
|
||||
self.vocab_size = 99
|
||||
self.hidden_size = 32
|
||||
self.num_hidden_layers = 5
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 37
|
||||
self.hidden_act = "gelu"
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 512
|
||||
self.type_vocab_size = 16
|
||||
self.type_sequence_label_size = 2
|
||||
self.initializer_range = 0.02
|
||||
self.num_labels = 3
|
||||
self.num_choices = 4
|
||||
self.scope = None
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = RoFormerConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFRoFormerModel(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
|
||||
inputs = [input_ids, input_mask]
|
||||
result = model(inputs)
|
||||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_lm_head(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.is_decoder = True
|
||||
model = TFRoFormerForCausalLM(config=config)
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
prediction_scores = model(inputs)["logits"]
|
||||
self.parent.assertListEqual(
|
||||
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFRoFormerForMaskedLM(config=config)
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFRoFormerForSequenceClassification(config=config)
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
model = TFRoFormerForMultipleChoice(config=config)
|
||||
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
|
||||
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
|
||||
inputs = {
|
||||
"input_ids": multiple_choice_inputs_ids,
|
||||
"attention_mask": multiple_choice_input_mask,
|
||||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
model = TFRoFormerForTokenClassification(config=config)
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFRoFormerForQuestionAnswering(config=config)
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFRoFormerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
TFRoFormerModel,
|
||||
TFRoFormerForCausalLM,
|
||||
TFRoFormerForMaskedLM,
|
||||
TFRoFormerForQuestionAnswering,
|
||||
TFRoFormerForSequenceClassification,
|
||||
TFRoFormerForTokenClassification,
|
||||
TFRoFormerForMultipleChoice,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFRoFormerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=RoFormerConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_lm_head(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = TFRoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFRoFormerModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
|
||||
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
# TODO Replace vocab size
|
||||
vocab_size = 50000
|
||||
|
||||
expected_shape = [1, 6, vocab_size]
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
print(output[:, :3, :3])
|
||||
|
||||
# TODO Replace values below with what was printed above.
|
||||
expected_slice = tf.constant(
|
||||
[
|
||||
[
|
||||
[-0.12053341, -1.0264901, 0.29221946],
|
||||
[-1.5133783, 0.197433, 0.15190607],
|
||||
[-5.0135403, -3.900256, -0.84038764],
|
||||
]
|
||||
]
|
||||
)
|
||||
tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFRoFormerSinusoidalPositionalEmbeddingTest(unittest.TestCase):
|
||||
tolerance = 1e-4
|
||||
|
||||
def test_basic(self):
|
||||
input_ids = tf.constant([[4, 10]])
|
||||
emb1 = TFRoFormerSinusoidalPositionalEmbedding(num_positions=6, embedding_dim=6)
|
||||
|
||||
emb = emb1(input_ids.shape)
|
||||
desired_weights = tf.constant(
|
||||
[[0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000], [0.8415, 0.0464, 0.0022, 0.5403, 0.9989, 1.0000]]
|
||||
)
|
||||
|
||||
tf.debugging.assert_near(emb, desired_weights, atol=self.tolerance)
|
||||
|
||||
def test_positional_emb_weights_against_roformer(self):
|
||||
|
||||
desired_weights = tf.constant(
|
||||
[
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.8415, 0.8219, 0.8020, 0.7819, 0.7617],
|
||||
[0.9093, 0.9364, 0.9581, 0.9749, 0.9870],
|
||||
]
|
||||
)
|
||||
emb1 = TFRoFormerSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
|
||||
emb1([2, 16, 512])
|
||||
weights = emb1.weight[:3, :5]
|
||||
|
||||
tf.debugging.assert_near(weights, desired_weights, atol=self.tolerance)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFRoFormerSelfAttentionRotaryPositionEmbeddingTest(unittest.TestCase):
|
||||
tolerance = 1e-4
|
||||
|
||||
def test_apply_rotary_position_embeddings(self):
|
||||
# 2,12,16,64
|
||||
query_layer = tf.reshape(tf.range(2 * 12 * 16 * 64, dtype=tf.float32), shape=(2, 12, 16, 64)) / 100
|
||||
|
||||
key_layer = -tf.reshape(tf.range(2 * 12 * 16 * 64, dtype=tf.float32), shape=(2, 12, 16, 64)) / 100
|
||||
|
||||
embed_positions = TFRoFormerSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=64)
|
||||
sinusoidal_pos = embed_positions([2, 16, 768])[None, None, :, :]
|
||||
|
||||
query_layer, key_layer = TFRoFormerSelfAttention.apply_rotary_position_embeddings(
|
||||
sinusoidal_pos, query_layer, key_layer
|
||||
)
|
||||
|
||||
desired_query_layer = tf.constant(
|
||||
[
|
||||
[0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700],
|
||||
[-0.2012, 0.8897, 0.0263, 0.9401, 0.2074, 0.9463, 0.3481, 0.9343],
|
||||
[-1.7057, 0.6271, -1.2145, 1.3897, -0.6303, 1.7647, -0.1173, 1.8985],
|
||||
[-2.1731, -1.6397, -2.7358, 0.2854, -2.1840, 1.7183, -1.3018, 2.4871],
|
||||
[0.2717, -3.6173, -2.9206, -2.1988, -3.6638, 0.3858, -2.9155, 2.2980],
|
||||
[3.9859, -2.1580, -0.7984, -4.4904, -4.1181, -2.0252, -4.4782, 1.1253],
|
||||
]
|
||||
)
|
||||
desired_key_layer = tf.constant(
|
||||
[
|
||||
[0.0000, -0.0100, -0.0200, -0.0300, -0.0400, -0.0500, -0.0600, -0.0700],
|
||||
[0.2012, -0.8897, -0.0263, -0.9401, -0.2074, -0.9463, -0.3481, -0.9343],
|
||||
[1.7057, -0.6271, 1.2145, -1.3897, 0.6303, -1.7647, 0.1173, -1.8985],
|
||||
[2.1731, 1.6397, 2.7358, -0.2854, 2.1840, -1.7183, 1.3018, -2.4871],
|
||||
[-0.2717, 3.6173, 2.9206, 2.1988, 3.6638, -0.3858, 2.9155, -2.2980],
|
||||
[-3.9859, 2.1580, 0.7984, 4.4904, 4.1181, 2.0252, 4.4782, -1.1253],
|
||||
]
|
||||
)
|
||||
|
||||
tf.debugging.assert_near(query_layer[0, 0, :6, :8], desired_query_layer, atol=self.tolerance)
|
||||
tf.debugging.assert_near(key_layer[0, 0, :6, :8], desired_key_layer, atol=self.tolerance)
|
84
tests/test_tokenization_roformer.py
Normal file
84
tests/test_tokenization_roformer.py
Normal file
@ -0,0 +1,84 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import unittest
|
||||
|
||||
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
|
||||
from transformers.testing_utils import require_tokenizers
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
def is_rjieba_available():
|
||||
return importlib.util.find_spec("rjieba") is not None
|
||||
|
||||
|
||||
def require_rjieba(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
|
||||
"""
|
||||
if not is_rjieba_available():
|
||||
return unittest.skip("test requires rjieba")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
@require_rjieba
|
||||
@require_tokenizers
|
||||
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = RoFormerTokenizer
|
||||
rust_tokenizer_class = RoFormerTokenizerFast
|
||||
space_between_special_tokens = True
|
||||
test_rust_tokenizer = True
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return self.tokenizer_class.from_pretrained("junnyu/roformer_chinese_base", **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return self.rust_tokenizer_class.from_pretrained("junnyu/roformer_chinese_base", **kwargs)
|
||||
|
||||
def get_chinese_input_output_texts(self):
|
||||
input_text = "永和服装饰品有限公司,今天天气非常好"
|
||||
output_text = "永和 服装 饰品 有限公司 , 今 天 天 气 非常 好"
|
||||
return input_text, output_text
|
||||
|
||||
def test_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
input_text, output_text = self.get_chinese_input_output_texts()
|
||||
tokens = tokenizer.tokenize(input_text)
|
||||
|
||||
self.assertListEqual(tokens, output_text.split())
|
||||
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
|
||||
|
||||
def test_rust_tokenizer(self):
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
input_text, output_text = self.get_chinese_input_output_texts()
|
||||
tokens = tokenizer.tokenize(input_text)
|
||||
self.assertListEqual(tokens, output_text.split())
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
|
||||
|
||||
# due to custom pre_tokenize , char_to_token may be error
|
||||
def test_alignement_methods(self):
|
||||
pass
|
Loading…
Reference in New Issue
Block a user