Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)

Merge pull request #964 from huggingface/RoBERTa
RoBERTa: model conversion, inference, tests 🔥

commit 88efc65bac
@ -12,6 +12,7 @@ The library currently contains PyTorch implementations, pre-trained model weight
4. **[Transformer-XL](https://github.com/kimiyoung/transformer-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
5. **[XLNet](https://github.com/zihangdai/xlnet/)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du et al.

These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/pytorch-transformers/examples.html).
@ -47,3 +47,4 @@ The library currently contains PyTorch implementations, pre-trained model weight
model_doc/gpt2
model_doc/xlm
model_doc/xlnet
model_doc/roberta
36
docs/source/model_doc/roberta.rst
Normal file
@ -0,0 +1,36 @@
RoBERTa
----------------------------------------------------

``RobertaConfig``
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_transformers.RobertaConfig
    :members:

``RobertaTokenizer``
~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_transformers.RobertaTokenizer
    :members:

``RobertaModel``
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_transformers.RobertaModel
    :members:

``RobertaForMaskedLM``
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_transformers.RobertaForMaskedLM
    :members:

``RobertaForSequenceClassification``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_transformers.RobertaForSequenceClassification
    :members:
@ -4,97 +4,109 @@ Pretrained models
Here is the full list of the currently provided pretrained models together with a short presentation of each model.

+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| Architecture | Shortcut name | Details of the model |
|
||||
+===================+============================================================+===========================================================================================================================+
|
||||
| BERT | ``bert-base-uncased`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | Trained on lower-cased English text |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-uncased`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | Trained on lower-cased English text |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-cased`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | Trained on cased English text |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-cased`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | Trained on cased English text |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-multilingual-uncased`` | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | Trained on lower-cased text in the top 102 languages with the largest Wikipedias |
|
||||
| | | (see `details <https://github.com/google-research/bert/blob/master/multilingual.md>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-multilingual-cased`` | (New, **recommended**) 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | Trained on cased text in the top 104 languages with the largest Wikipedias |
|
||||
| | | (see `details <https://github.com/google-research/bert/blob/master/multilingual.md>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-chinese`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | Trained on cased Chinese Simplified and Traditional text |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-german-cased`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | Trained on cased German text by Deepset.ai |
|
||||
| | | (see `details on deepset.ai website <https://deepset.ai/german-bert>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-uncased-whole-word-masking`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | Trained on lower-cased English text using Whole-Word-Masking |
|
||||
| | | (see `details <https://github.com/google-research/bert/#bert>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-cased-whole-word-masking`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | Trained on cased English text using Whole-Word-Masking |
|
||||
| | | (see `details <https://github.com/google-research/bert/#bert>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-uncased-whole-word-masking-finetuned-squad`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | The ``bert-large-uncased-whole-word-masking`` model fine-tuned on SQuAD (see details of fine-tuning in the |
|
||||
| | | `example section <https://github.com/huggingface/pytorch-transformers/tree/master/examples>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-cased-whole-word-masking-finetuned-squad`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | The ``bert-large-cased-whole-word-masking`` model fine-tuned on SQuAD |
|
||||
| | | (see `details of fine-tuning in the example section <https://huggingface.co/pytorch-transformers/examples.html>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-cased-finetuned-mrpc`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | The ``bert-base-cased`` model fine-tuned on MRPC |
|
||||
| | | (see `details of fine-tuning in the example section <https://huggingface.co/pytorch-transformers/examples.html>`__) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| GPT | ``openai-gpt`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | OpenAI GPT English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| GPT-2 | ``gpt2`` | 12-layer, 768-hidden, 12-heads, 117M parameters |
|
||||
| | | OpenAI GPT-2 English model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``gpt2-medium`` | 24-layer, 1024-hidden, 16-heads, 345M parameters |
|
||||
| | | OpenAI's Medium-sized GPT-2 English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| Transformer-XL | ``transfo-xl-wt103`` | 18-layer, 1024-hidden, 16-heads, 257M parameters |
|
||||
| | | English model trained on wikitext-103 |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| XLNet | ``xlnet-base-cased`` | 12-layer, 768-hidden, 12-heads, 110M parameters |
|
||||
| | | XLNet English model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlnet-large-cased`` | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | XLNet Large English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| XLM | ``xlm-mlm-en-2048`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM English model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-ende-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM English-German Multi-language model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-enfr-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM English-French Multi-language model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-enro-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM English-Romanian Multi-language model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-xnli15-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM Model pre-trained with MLM on the `15 XNLI languages <https://github.com/facebookresearch/XNLI>`__. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-tlm-xnli15-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM Model pre-trained with MLM + TLM on the `15 XNLI languages <https://github.com/facebookresearch/XNLI>`__. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-clm-enfr-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM English model trained with CLM (Causal Language Modeling) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-clm-ende-1024`` | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------+
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| Architecture | Shortcut name | Details of the model |
|
||||
+===================+============================================================+=======================================================================================================================================+
|
||||
| BERT | ``bert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | Trained on lower-cased English text. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-uncased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||
| | | | Trained on lower-cased English text. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | Trained on cased English text. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||
| | | | Trained on cased English text. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-multilingual-uncased`` | | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | Trained on lower-cased text in the top 102 languages with the largest Wikipedias |
|
||||
| | | (see `details <https://github.com/google-research/bert/blob/master/multilingual.md>`__). |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-multilingual-cased`` | | (New, **recommended**) 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | Trained on cased text in the top 104 languages with the largest Wikipedias |
|
||||
| | | (see `details <https://github.com/google-research/bert/blob/master/multilingual.md>`__). |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-chinese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | Trained on cased Chinese Simplified and Traditional text. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-german-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | Trained on cased German text by Deepset.ai |
|
||||
| | | (see `details on deepset.ai website <https://deepset.ai/german-bert>`__). |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-uncased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||
| | | | Trained on lower-cased English text using Whole-Word-Masking |
|
||||
| | | (see `details <https://github.com/google-research/bert/#bert>`__). |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-cased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||
| | | | Trained on cased English text using Whole-Word-Masking |
|
||||
| | | (see `details <https://github.com/google-research/bert/#bert>`__). |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-uncased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||
| | | | The ``bert-large-uncased-whole-word-masking`` model fine-tuned on SQuAD |
|
||||
| | | (see details of fine-tuning in the `example section <https://github.com/huggingface/pytorch-transformers/tree/master/examples>`__). |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-large-cased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters |
|
||||
| | | | The ``bert-large-cased-whole-word-masking`` model fine-tuned on SQuAD |
|
||||
| | | (see `details of fine-tuning in the example section <https://huggingface.co/pytorch-transformers/examples.html>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``bert-base-cased-finetuned-mrpc`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | The ``bert-base-cased`` model fine-tuned on MRPC |
|
||||
| | | (see `details of fine-tuning in the example section <https://huggingface.co/pytorch-transformers/examples.html>`__) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| GPT | ``openai-gpt`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | OpenAI GPT English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| GPT-2 | ``gpt2`` | | 12-layer, 768-hidden, 12-heads, 117M parameters. |
|
||||
| | | | OpenAI GPT-2 English model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``gpt2-medium`` | | 24-layer, 1024-hidden, 16-heads, 345M parameters. |
|
||||
| | | | OpenAI's Medium-sized GPT-2 English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
|
||||
| | | | English model trained on wikitext-103 |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| XLNet | ``xlnet-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
|
||||
| | | | XLNet English model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
|
||||
| | | | XLNet Large English model |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM English model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM English-German Multi-language model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM English-French Multi-language model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-enro-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM English-Romanian Multi-language model |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM Model pre-trained with MLM on the `15 XNLI languages <https://github.com/facebookresearch/XNLI>`__. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-mlm-tlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM Model pre-trained with MLM + TLM on the `15 XNLI languages <https://github.com/facebookresearch/XNLI>`__. |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM English model trained with CLM (Causal Language Modeling) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``xlm-clm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
|
||||
| | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
|
||||
| | | | RoBERTa using the BERT-base architecture |
|
||||
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``roberta-large`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters |
|
||||
| | | | RoBERTa using the BERT-large architecture |
|
||||
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``roberta-large-mnli`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters |
|
||||
| | | | ``roberta-large`` fine-tuned on `MNLI <http://www.nyu.edu/projects/bowman/multinli/>`__. |
|
||||
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
|
||||
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
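As a minimal usage sketch (mirroring the examples in the model docstrings; ``roberta-base`` here stands in for any shortcut name in the table above)::

    import torch
    from pytorch_transformers import RobertaModel, RobertaTokenizer

    # Any shortcut name from the table can be passed to from_pretrained();
    # weights and config are downloaded and cached on first use.
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaModel.from_pretrained('roberta-base')
    model.eval()  # disable dropout for inference

    input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])  # batch size 1
    last_hidden_states = model(input_ids)[0]  # (batch_size, sequence_length, hidden_size)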
.. <https://huggingface.co/pytorch-transformers/examples.html>`__
@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet)."""
|
||||
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
@ -33,6 +33,9 @@ from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
BertForSequenceClassification, BertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaTokenizer,
|
||||
XLMConfig, XLMForSequenceClassification,
|
||||
XLMTokenizer, XLNetConfig,
|
||||
XLNetForSequenceClassification,
|
||||
@ -45,12 +48,13 @@ from utils_glue import (compute_metrics, convert_examples_to_features,
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
||||
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
||||
}
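# Context (a sketch of existing code elsewhere in this script, not part of this diff): main()
# looks the (config, model, tokenizer) classes up by --model_type, so registering 'roberta'
# above is what enables e.g. `--model_type roberta --model_name_or_path roberta-base`:
#
#   config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
#   config = config_class.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
#   tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
#   model = model_class.from_pretrained(args.model_name_or_path, config=config)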
|
||||
|
||||
|
||||
@ -214,7 +218,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
@ -264,14 +268,20 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
||||
label_list = processor.get_labels()
|
||||
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
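# Illustration (assuming the usual utils_glue MNLI label order ["contradiction", "entailment", "neutral"]):
# the swap yields ["contradiction", "neutral", "entailment"], i.e. the label order baked into
# the fairseq ``roberta-large-mnli`` classification head.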
|
||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
|
||||
cls_token=tokenizer.cls_token,
|
||||
sep_token=tokenizer.sep_token,
|
||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||
sep_token=tokenizer.sep_token,
|
||||
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
|
||||
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
|
||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
|
||||
pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token],
|
||||
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
|
||||
)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
@ -390,10 +390,16 @@ class WnliProcessor(DataProcessor):
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
tokenizer, output_mode,
|
||||
cls_token_at_end=False, pad_on_left=False,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=1, pad_token_segment_id=0,
|
||||
cls_token_at_end=False,
|
||||
cls_token='[CLS]',
|
||||
cls_token_segment_id=1,
|
||||
sep_token='[SEP]',
|
||||
sep_token_extra=False,
|
||||
pad_on_left=False,
|
||||
pad_token=0,
|
||||
pad_token_segment_id=0,
|
||||
sequence_a_segment_id=0,
|
||||
sequence_b_segment_id=1,
|
||||
mask_padding_with_zero=True):
|
||||
""" Loads a data file into a list of `InputBatch`s
|
||||
`cls_token_at_end` define the location of the CLS token:
|
||||
@ -442,6 +448,9 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
# used as the "sentence vector". Note that this only makes sense because
|
||||
# the entire model is fine-tuned.
|
||||
tokens = tokens_a + [sep_token]
|
||||
if sep_token_extra:
|
||||
# roberta uses an extra separator b/w pairs of sentences
|
||||
tokens += [sep_token]
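# Sketch of the final layout for a sentence pair, once tokens_b and the CLS token are added below
# (RoBERTa's cls/sep tokens are <s> and </s>):
#   BERT    (sep_token_extra=False): [CLS] tokens_a [SEP] tokens_b [SEP]
#   RoBERTa (sep_token_extra=True):  <s> tokens_a </s> </s> tokens_b </s>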
|
||||
segment_ids = [sequence_a_segment_id] * len(tokens)
|
||||
|
||||
if tokens_b:
|
||||
|
@ -6,6 +6,8 @@ from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
|
||||
from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
|
||||
from .tokenization_utils import (PreTrainedTokenizer)
|
||||
|
||||
from .modeling_auto import (AutoConfig, AutoModel)
|
||||
@ -36,6 +38,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
|
||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
|
||||
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
|
||||
|
||||
|
181
pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py
Normal file
@ -0,0 +1,181 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert RoBERTa checkpoint."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||
from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
|
||||
BertIntermediate, BertLayer,
|
||||
BertModel, BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput)
|
||||
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaModel)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||
|
||||
|
||||
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
|
||||
"""
|
||||
Copy/paste/tweak roberta's weights to our BERT structure.
|
||||
"""
|
||||
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
|
||||
roberta.eval() # disable dropout
|
||||
config = BertConfig(
|
||||
vocab_size_or_config_json_file=50265,
|
||||
hidden_size=roberta.args.encoder_embed_dim,
|
||||
num_hidden_layers=roberta.args.encoder_layers,
|
||||
num_attention_heads=roberta.args.encoder_attention_heads,
|
||||
intermediate_size=roberta.args.encoder_ffn_embed_dim,
|
||||
max_position_embeddings=514,
|
||||
type_vocab_size=1,
|
||||
)
|
||||
if classification_head:
|
||||
config.num_labels = roberta.args.num_classes
|
||||
print("Our BERT config:", config)
|
||||
|
||||
model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
|
||||
model.eval()
|
||||
|
||||
# Now let's copy all the weights.
|
||||
# Embeddings
|
||||
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
|
||||
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
|
||||
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
|
||||
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
|
||||
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
|
||||
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
|
||||
model.roberta.embeddings.LayerNorm.variance_epsilon = roberta_sent_encoder.emb_layer_norm.eps
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
# Encoder: start of layer
|
||||
layer: BertLayer = model.roberta.encoder.layer[i]
|
||||
roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
|
||||
|
||||
### self attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
assert(
|
||||
roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
|
||||
)
|
||||
# we use three distinct linear layers so we split the source layer here.
|
||||
self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
|
||||
self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
|
||||
self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
|
||||
self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
|
||||
self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
|
||||
self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]
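# e.g. for roberta-large (hidden_size=1024), the 3072x1024 in_proj_weight is sliced into
# query = rows 0:1024, key = rows 1024:2048, value = rows 2048:3072.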
|
||||
|
||||
### self-attention output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
assert(
|
||||
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
|
||||
)
|
||||
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
|
||||
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
|
||||
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
|
||||
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
|
||||
self_output.LayerNorm.variance_epsilon = roberta_layer.self_attn_layer_norm.eps
|
||||
|
||||
### intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
assert(
|
||||
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
|
||||
)
|
||||
intermediate.dense.weight = roberta_layer.fc1.weight
|
||||
intermediate.dense.bias = roberta_layer.fc1.bias
|
||||
|
||||
### output
|
||||
bert_output: BertOutput = layer.output
|
||||
assert(
|
||||
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
|
||||
)
|
||||
bert_output.dense.weight = roberta_layer.fc2.weight
|
||||
bert_output.dense.bias = roberta_layer.fc2.bias
|
||||
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
|
||||
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
|
||||
bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
|
||||
#### end of layer
|
||||
|
||||
if classification_head:
|
||||
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
|
||||
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
|
||||
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
|
||||
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
|
||||
else:
|
||||
# LM Head
|
||||
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
|
||||
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
|
||||
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
|
||||
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
|
||||
model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
|
||||
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
|
||||
model.lm_head.bias = roberta.model.decoder.lm_head.bias
|
||||
|
||||
# Let's check that we get the same results.
|
||||
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
|
||||
|
||||
our_output = model(input_ids)[0]
|
||||
if classification_head:
|
||||
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
|
||||
else:
|
||||
their_output = roberta.model(input_ids)[0]
|
||||
print(our_output.shape, their_output.shape)
|
||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||
print(
|
||||
"Do both models output the same tensors?",
|
||||
"🔥" if success else "💩"
|
||||
)
|
||||
if not success:
|
||||
raise Exception("Something went wRoNg")
|
||||
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--roberta_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the official PyTorch dump.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
parser.add_argument("--classification_head",
|
||||
action = "store_true",
|
||||
help = "Whether to convert a final classification head.")
|
||||
args = parser.parse_args()
|
||||
convert_roberta_checkpoint_to_pytorch(
|
||||
args.roberta_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.classification_head
|
||||
)
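
# Example invocation (paths are hypothetical):
#   python convert_roberta_checkpoint_to_pytorch.py \
#       --roberta_checkpoint_path /path/to/fairseq/roberta.large \
#       --pytorch_dump_folder_path /path/to/output/roberta-large \
#       --classification_head   # only when the fairseq checkpoint carries the 'mnli' head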
|
||||
|
349
pytorch_transformers/modeling_roberta.py
Normal file
@ -0,0 +1,349 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch RoBERTa model. """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
|
||||
BertLayerNorm, BertModel,
|
||||
BertPreTrainedModel, gelu)
|
||||
|
||||
from pytorch_transformers.modeling_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
|
||||
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
|
||||
}
|
||||
|
||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
||||
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
|
||||
}
|
||||
|
||||
|
||||
class RobertaEmbeddings(BertEmbeddings):
|
||||
"""
|
||||
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(RobertaEmbeddings, self).__init__(config)
|
||||
self.padding_idx = 1
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, position_ids=None):
|
||||
seq_length = input_ids.size(1)
|
||||
if position_ids is None:
|
||||
# Position numbers begin at padding_idx+1. Padding symbols are ignored.
|
||||
# cf. fairseq's `utils.make_positions`
|
||||
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
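# Sketch: with padding_idx = 1 and seq_length = 5, the default position ids are
# torch.arange(2, 7) -> [2, 3, 4, 5, 6] (fairseq-style) rather than [0, 1, 2, 3, 4] as in BERT,
# which is why the conversion script above sets max_position_embeddings to 514 for a 512-token model.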
|
||||
|
||||
|
||||
class RobertaConfig(BertConfig):
|
||||
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
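# RobertaConfig only overrides the pretrained archive map; every hyperparameter is inherited from
# BertConfig, so e.g. RobertaConfig.from_pretrained('roberta-base') resolves to the
# roberta-base-config.json URL in ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP above.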
|
||||
|
||||
|
||||
ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
|
||||
`RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
|
||||
by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer,
|
||||
Veselin Stoyanov. It is based on Google's BERT model released in 2018.
|
||||
|
||||
It builds on BERT and modifies key hyperparameters, removing the next-sentence pretraining
|
||||
objective and training with much larger mini-batches and learning rates.
|
||||
|
||||
This implementation is the same as BertModel with a tiny embeddings tweak as well as a setup for Roberta pretrained
|
||||
models.
|
||||
|
||||
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
|
||||
refer to the PyTorch documentation for all matter related to general usage and behavior.
|
||||
|
||||
.. _`RoBERTa: A Robustly Optimized BERT Pretraining Approach`:
|
||||
https://arxiv.org/abs/1907.11692
|
||||
|
||||
.. _`torch.nn.Module`:
|
||||
https://pytorch.org/docs/stable/nn.html#module
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
||||
model.
|
||||
"""
|
||||
|
||||
ROBERTA_INPUTS_DOCSTRING = r"""
|
||||
Inputs:
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
To match pre-training, the RoBERTa input sequence should be formatted with [CLS] and [SEP] tokens as follows:
|
||||
|
||||
(a) For sequence pairs:
|
||||
|
||||
``tokens: [CLS] is this jack ##son ##ville ? [SEP][SEP] no it is not . [SEP]``
|
||||
|
||||
(b) For single sequences:
|
||||
|
||||
``tokens: [CLS] the dog is hairy . [SEP]``
|
||||
|
||||
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
|
||||
the ``add_special_tokens`` parameter set to ``True``.
|
||||
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1[``.
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
|
||||
class RobertaModel(BertModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during Bert pretraining. This output is usually *not* a good summary
|
||||
of the semantic content of the input; you're often better off averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
||||
model = RobertaModel.from_pretrained('roberta-base')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
config_class = RobertaConfig
|
||||
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super(RobertaModel, self).__init__(config)
|
||||
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
|
||||
if input_ids[:, 0].sum().item() != 0:
|
||||
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
|
||||
"This model requires special tokens in order to work. "
|
||||
"Please specify add_special_tokens=True in your encoding.")
|
||||
return super(RobertaModel, self).forward(input_ids, token_type_ids, attention_mask, position_ids, head_mask)
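# As the warning above indicates, inputs are expected to carry RoBERTa's special tokens, e.g.
# (a sketch): input_ids = torch.tensor([tokenizer.encode("Hello world!", add_special_tokens=True)])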
|
||||
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
|
||||
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
|
||||
class RobertaForMaskedLM(BertPreTrainedModel):
|
||||
r"""
|
||||
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Labels for computing the masked language modeling loss.
|
||||
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
|
||||
model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, masked_lm_labels=input_ids)
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
config_class = RobertaConfig
|
||||
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
def __init__(self, config):
|
||||
super(RobertaForMaskedLM, self).__init__(config)
|
||||
|
||||
self.roberta = RobertaModel(config)
|
||||
self.lm_head = RobertaLMHead(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, position_ids=None,
|
||||
head_mask=None):
|
||||
outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask, head_mask=head_mask)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
|
||||
if masked_lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
outputs = (masked_lm_loss,) + outputs
|
||||
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
class RobertaLMHead(nn.Module):
|
||||
"""Roberta Head for masked language modeling."""
|
||||
|
||||
def __init__(self, config):
|
||||
super(RobertaLMHead, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = self.dense(features)
|
||||
x = gelu(x)
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# project back to size of vocabulary with bias
|
||||
x = self.decoder(x) + self.bias
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@add_start_docstrings("""RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
    on top of the pooled output) e.g. for GLUE tasks. """,
    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForSequenceClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForSequenceClassification.from_pretrained('roberta-base')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
        super(RobertaForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config)
        self.classifier = RobertaClassificationHead(config)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                               attention_mask=attention_mask, head_mask=head_mask)
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super(RobertaClassificationHead, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

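# Training-step sketch (illustrative): one optimization step for sentence-pair classification with
# RobertaForSequenceClassification. The 'roberta-base' checkpoint comes from the archive map above;
# the two-label setup, the example labels and the Adam hyper-parameters are assumptions.

import torch
from pytorch_transformers import RobertaTokenizer, RobertaForSequenceClassification

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
model.train()

# encode() with add_special_tokens=True wraps a pair as <s> A </s></s> B </s> (see the tokenizer below)
input_ids = torch.tensor([tokenizer.encode("A man is playing a guitar.",
                                           "Someone is making music.",
                                           add_special_tokens=True)])
labels = torch.tensor([1])  # assumed labels: 1 = "related", 0 = "unrelated"

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss, logits = model(input_ids, labels=labels)[:2]
loss.backward()
optimizer.step()
optimizer.zero_grad()
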
242
pytorch_transformers/tests/modeling_roberta_test.py
Normal file
@ -0,0 +1,242 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest
import torch

from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP

from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)


class RobertaModelTest(CommonTestCases.CommonModelTester):

    all_model_classes = (RobertaForMaskedLM, RobertaModel)

    class RobertaModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     scope=None,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = RobertaConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def check_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_and_check_roberta_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels,
                                           token_labels, choice_labels):
            model = RobertaModel(config=config)
            model.eval()
            sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
            sequence_output, pooled_output = model(input_ids, token_type_ids)
            sequence_output, pooled_output = model(input_ids)

            result = {
                "sequence_output": sequence_output,
                "pooled_output": pooled_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

        def create_and_check_roberta_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels,
                                                   token_labels, choice_labels):
            model = RobertaForMaskedLM(config=config)
            model.eval()
            loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, token_type_ids, input_mask,
             sequence_labels, token_labels, choice_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = RobertaModelTest.RobertaModelTester(self)
        self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_roberta_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_roberta_model(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

class RobertaModelIntegrationTest(unittest.TestCase):

    @pytest.mark.slow
    def test_inference_masked_lm(self):
        model = RobertaForMaskedLM.from_pretrained('roberta-base')

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        expected_shape = torch.Size((1, 11, 50265))
        self.assertEqual(
            output.shape,
            expected_shape
        )
        # compare the actual values for a slice.
        expected_slice = torch.Tensor(
            [[[33.8843, -4.3107, 22.7779],
              [4.6533, -2.8099, 13.6252],
              [1.8222, -3.6898, 8.8600]]]
        )
        self.assertTrue(
            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
        )

    @pytest.mark.slow
    def test_inference_no_head(self):
        model = RobertaModel.from_pretrained('roberta-base')

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        # compare the actual values for a slice.
        expected_slice = torch.Tensor(
            [[[-0.0231, 0.0782, 0.0074],
              [-0.1854, 0.0539, -0.0174],
              [0.0548, 0.0799, 0.1687]]]
        )
        self.assertTrue(
            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
        )

    @pytest.mark.slow
    def test_inference_classification_head(self):
        model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')

        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        expected_shape = torch.Size((1, 3))
        self.assertEqual(
            output.shape,
            expected_shape
        )
        expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
        self.assertTrue(
            torch.allclose(output, expected_tensor, atol=1e-3)
        )


if __name__ == "__main__":
    unittest.main()

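# Sketch (illustrative): the hard-coded ids in the integration tests above are just the byte-level
# BPE encoding of the probe sentence used elsewhere in this diff, wrapped in <s> ... </s>.
# Reproducing them from RobertaTokenizer makes the fixture easier to audit; 'roberta-base' is the
# same checkpoint name used above.

import torch
from pytorch_transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
ids = tokenizer.encode(u'Hello world! cécé herlolip 418', add_special_tokens=True)
# expected to match the fixture above: [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
input_ids = torch.tensor([ids])  # same tensor as in the tests above
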
@ -125,6 +125,17 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))

    def test_sequence_builders(self):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == [101] + text + [102]
        assert encoded_pair == [101] + text + [102] + text_2 + [102]

if __name__ == '__main__':
    unittest.main()

95
pytorch_transformers/tests/tokenization_roberta_test.py
Normal file
@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json
import unittest

from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases


class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
    tokenizer_class = RobertaTokenizer

    def setUp(self):
        super(RobertaTokenizationTest, self).setUp()

        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "lo", "low", "er",
                 "low", "lowest", "newer", "wider", "<unk>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
        with open(self.vocab_file, "w") as fp:
            fp.write(json.dumps(vocab_tokens))
        with open(self.merges_file, "w") as fp:
            fp.write("\n".join(merges))

    def get_tokenizer(self):
        return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)

    def get_input_output_texts(self):
        input_text = u"lower newer"
        output_text = u"lower<unk>newer"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower"
        bpe_tokens = ["low", "er"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [13, 12, 17]
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def roberta_dict_integration_testing(self):
        tokenizer = self.get_tokenizer()

        self.assertListEqual(
            tokenizer.encode('Hello world!'),
            [0, 31414, 232, 328, 2]
        )
        self.assertListEqual(
            tokenizer.encode('Hello world! cécé herlolip 418'),
            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
        )

    def test_sequence_builders(self):
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
        encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == encoded_text_from_decode
        assert encoded_pair == encoded_pair_from_decode


if __name__ == '__main__':
    unittest.main()

@ -66,6 +66,17 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_sequence_builders(self):
        tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == [1] + text + [1]
        assert encoded_pair == [1] + text + [1] + text_2 + [1]

if __name__ == '__main__':
    unittest.main()

@ -89,6 +89,18 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
                              u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                              SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])

    def test_sequence_builders(self):
        tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

        text = tokenizer.encode("sequence builders")
        text_2 = tokenizer.encode("multi-sequence build")

        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

        assert encoded_sentence == text + [4, 3]
        assert encoded_pair == text + [4] + text_2 + [4, 3]


if __name__ == '__main__':
    unittest.main()

@ -166,6 +166,22 @@ class BertTokenizer(PreTrainedTokenizer):
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string

    def add_special_tokens_single_sentence(self, token_ids):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        A BERT sequence has the following format: [CLS] X [SEP]
        """
        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
        """
        sep = [self._convert_token_to_id(self.sep_token)]
        cls = [self._convert_token_to_id(self.cls_token)]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0

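# Sketch (illustrative): what the two builders above produce for bert-base-uncased, matching the
# assertions in the BertTokenizationTest hunk earlier in this diff (101 = [CLS], 102 = [SEP]).

from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("sequence builders")

single = tokenizer.add_special_tokens_single_sentence(ids)    # [101] + ids + [102]
pair = tokenizer.add_special_tokens_sentences_pair(ids, ids)  # [101] + ids + [102] + ids + [102]
assert single[0] == 101 and single[-1] == 102
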
201
pytorch_transformers/tokenization_roberta.py
Normal file
@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import sys
import json
import logging
import os
import regex as re
from io import open

from .tokenization_gpt2 import bytes_to_unicode, get_pairs
from .tokenization_utils import PreTrainedTokenizer

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
    },
    'merges_file':
    {
        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'roberta-base': 512,
    'roberta-large': 512,
    'roberta-large-mnli': 512,
}

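# Sketch (illustrative): the shortcut names in the maps above can be passed straight to
# RobertaTokenizer.from_pretrained, which downloads vocab.json / merges.txt and caches them; a
# local directory written by save_vocabulary works the same way. The scratch path below is an
# assumption for the example.

import os
from pytorch_transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')  # resolves the URLs above

save_dir = '/tmp/roberta-vocab/'  # assumed scratch directory
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
vocab_file, merges_file = tokenizer.save_vocabulary(save_dir)  # writes vocab.json / merges.txt
reloaded = RobertaTokenizer.from_pretrained(save_dir)
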
class RobertaTokenizer(PreTrainedTokenizer):
    """
    RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: Byte-level BPE
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
                 cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs):
        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
                                               sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                                               mask_token=mask_token, **kwargs)

        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    @property
    def vocab_size(self):
        return len(self.encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """ Tokenize a string. """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = ''.join(self.byte_encoder[ord(b)] for b in token)
            else:
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) in an id using the vocab. """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

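# Sketch (illustrative): the byte-level pipeline above in isolation. A leading space is folded
# into the following token (rendered with the 'Ġ' placeholder byte), and convert_tokens_to_string
# reverses the byte mapping exactly, so tokenize -> ids -> decode round-trips the input.

from pytorch_transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

tokens = tokenizer.tokenize(u"Hello world!")  # e.g. ['Hello', 'Ġworld', '!']
ids = tokenizer.convert_tokens_to_ids(tokens)
assert tokenizer.convert_tokens_to_string(tokens) == u"Hello world!"
assert tokenizer.decode(ids) == u"Hello world!"
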
    def add_special_tokens_single_sentence(self, token_ids):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        A RoBERTa sequence has the following format: [CLS] X [SEP]
        """
        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        A RoBERTa sequence pair has the following format: [CLS] A [SEP][SEP] B [SEP]
        """
        sep = [self._convert_token_to_id(self.sep_token)]
        cls = [self._convert_token_to_id(self.cls_token)]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        return vocab_file, merge_file

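# Sketch (illustrative): the RoBERTa formats implemented above, using the 'roberta-base'
# vocabulary in which <s> maps to id 0 and </s> to id 2 (the ids the tokenizer tests earlier
# in this diff rely on).

from pytorch_transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

a = tokenizer.encode("sequence builders")  # plain BPE ids, no special tokens
b = tokenizer.encode("multi-sequence build")

single = tokenizer.add_special_tokens_single_sentence(a)  # <s> A </s>
pair = tokenizer.add_special_tokens_sentences_pair(a, b)  # <s> A </s></s> B </s>

assert single == [0] + a + [2]
assert pair == [0] + a + [2, 2] + b + [2]
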
@ -180,9 +180,10 @@ class PreTrainedTokenizer(object):

    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        r""" Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
        r"""
        Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.

        Parameters:
        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.

@ -383,14 +384,15 @@ class PreTrainedTokenizer(object):

    def add_tokens(self, new_tokens):
        """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
        vocabulary, they are added to it with indices starting from length of the current vocabulary.

        Parameters:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
        Args:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.
        Returns:
            Number of tokens added to the vocabulary.

        Examples::

@ -422,17 +424,20 @@ class PreTrainedTokenizer(object):

    def add_special_tokens(self, special_tokens_dict):
        """ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
            to class attributes. If special tokens are NOT in the vocabulary, they are added
            to it (indexed starting from the last index of the current vocabulary).
        """
        Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
        to class attributes. If special tokens are NOT in the vocabulary, they are added
        to it (indexed starting from the last index of the current vocabulary).

        Parameters:
            special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``].

            Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
        Args:
            special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
                [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
                ``additional_special_tokens``].

        Returns:
            Number of tokens added to the vocabulary.
            Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.

        Examples::

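# Sketch (illustrative): registering an extra special token with the method above and resizing the
# model embeddings to match. The '<speaker>' token and the model/tokenizer pairing are assumptions
# for the example.

from pytorch_transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

num_added = tokenizer.add_special_tokens({'additional_special_tokens': ['<speaker>']})
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))  # keep the embedding matrix in sync
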
@ -519,14 +524,37 @@ class PreTrainedTokenizer(object):

    def _convert_token_to_id(self, token):
        raise NotImplementedError

    def encode(self, text):
        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.

            Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
        """
        return self.convert_tokens_to_ids(self.tokenize(text))
    def encode(self, text, text_pair=None, add_special_tokens=False):
        """
        Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.

        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

        Args:
            text: The first sequence to be encoded.
            text_pair: Optional second sequence to be encoded.
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
        """
        if text_pair is None:
            if add_special_tokens:
                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
            else:
                return self.convert_tokens_to_ids(self.tokenize(text))

        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]

        if add_special_tokens:
            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
        else:
            return first_sentence_tokens, second_sentence_tokens

    def add_special_tokens_single_sentence(self, token_ids):
        raise NotImplementedError

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        raise NotImplementedError

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """ Converts a single index or a sequence of indices (integers) in a token "

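# Sketch (illustrative): the three calling patterns of the reworked encode() above, shown with
# BertTokenizer; any tokenizer implementing the two add_special_tokens_* hooks behaves the same.

from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

plain = tokenizer.encode("sequence builders")                            # ids only
single = tokenizer.encode("sequence builders", add_special_tokens=True)  # [CLS] ... [SEP]
pair = tokenizer.encode("sequence builders", "multi-sequence build",
                        add_special_tokens=True)                         # [CLS] A [SEP] B [SEP]

assert single == [101] + plain + [102]
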
@ -561,16 +589,28 @@ class PreTrainedTokenizer(object):
        return ' '.join(self.convert_ids_to_tokens(tokens))

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
            with options to remove special tokens and clean up tokenization spaces.
        """
        Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
        with options to remove special tokens and clean up tokenization spaces.
        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
        """
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        text = self.convert_tokens_to_string(filtered_tokens)
        if clean_up_tokenization_spaces:
            text = self.clean_up_tokenization(text)
        return text

        if self.sep_token is not None and self.sep_token in text:
            text = text.replace(self.cls_token, self.sep_token)
            split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
            if clean_up_tokenization_spaces:
                clean_text = [self.clean_up_tokenization(text) for text in split_text]
                return clean_text
            else:
                return split_text
        else:
            if clean_up_tokenization_spaces:
                clean_text = self.clean_up_tokenization(text)
                return clean_text
            else:
                return text

    @property
    def special_tokens_map(self):

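# Sketch (illustrative): with the change above, decode() now returns a list of segments when the
# ids still contain the tokenizer's sep_token, instead of one concatenated string.

from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

pair_ids = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
decoded = tokenizer.decode(pair_ids)
# expected: a list with one (roughly whitespace-padded) string per sentence, e.g.
#   [' sequence builders ', ' multi - sequence build ']
print(decoded)
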
@ -602,7 +642,7 @@ class PreTrainedTokenizer(object):
            class attributes (cls_token, unk_token...).
        """
        all_toks = self.all_special_tokens
        all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
        all_ids = list(self._convert_token_to_id(t) for t in all_toks)
        return all_ids

    @staticmethod

@ -214,6 +214,22 @@ class XLMTokenizer(PreTrainedTokenizer):
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        return out_string

    def add_special_tokens_single_sentence(self, token_ids):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        An XLM sequence has the following format: [CLS] X [SEP]
        """
        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
        """
        sep = [self._convert_token_to_id(self.sep_token)]
        cls = [self._convert_token_to_id(self.cls_token)]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):

@ -177,6 +177,24 @@ class XLNetTokenizer(PreTrainedTokenizer):
        out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
        return out_string

    def add_special_tokens_single_sentence(self, token_ids):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        An XLNet sequence has the following format: X [SEP][CLS]
        """
        sep = [self._convert_token_to_id(self.sep_token)]
        cls = [self._convert_token_to_id(self.cls_token)]
        return token_ids + sep + cls

    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
        """
        sep = [self._convert_token_to_id(self.sep_token)]
        cls = [self._convert_token_to_id(self.cls_token)]
        return token_ids_0 + sep + token_ids_1 + sep + cls

    def save_vocabulary(self, save_directory):
        """ Save the sentencepiece vocabulary (copy original file) and special tokens file
            to a directory.