Merge pull request #964 from huggingface/RoBERTa

RoBERTa: model conversion, inference, tests 🔥
Lysandre Debut 2019-08-15 11:11:10 -04:00 committed by GitHub
commit 88efc65bac
19 changed files with 1395 additions and 130 deletions


@@ -12,6 +12,7 @@ The library currently contains PyTorch implementations, pre-trained model weight
4. **[Transformer-XL](https://github.com/kimiyoung/transformer-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
5. **[XLNet](https://github.com/zihangdai/xlnet/)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook) released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du et al.
These implementations have been tested on several datasets (see the example scripts) and should match the performance of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performance in the Examples section of the [documentation](https://huggingface.co/pytorch-transformers/examples.html).
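
As a quick smoke test of the new RoBERTa support, the snippet below mirrors the usage examples shipped in the model docstrings (a minimal sketch; it assumes the `roberta-base` weights and vocabulary download correctly):

```python
import torch
from pytorch_transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
model.eval()

# add_special_tokens=True wraps the text in <s> ... </s>; RobertaModel logs a
# warning if these special tokens are missing from the input.
input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)])
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # (1, sequence_length, hidden_size)
```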


@@ -47,3 +47,4 @@ The library currently contains PyTorch implementations, pre-trained model weight
model_doc/gpt2
model_doc/xlm
model_doc/xlnet
model_doc/roberta


@@ -0,0 +1,36 @@
RoBERTa
----------------------------------------------------
``RobertaConfig``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaConfig
:members:
``RobertaTokenizer``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaTokenizer
:members:
``RobertaModel``
~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaModel
:members:
``RobertaForMaskedLM``
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaForMaskedLM
:members:
``RobertaForSequenceClassification``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: pytorch_transformers.RobertaForSequenceClassification
:members:


@@ -4,97 +4,109 @@ Pretrained models
Here is the full list of the currently provided pretrained models together with a short presentation of each model.
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Architecture | Shortcut name | Details of the model |
+===================+============================================================+=======================================================================================================================================+
| BERT | ``bert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on lower-cased English text. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-large-uncased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
| | | | Trained on lower-cased English text. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased English text. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
| | | | Trained on cased English text. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-multilingual-uncased`` | | (Original, not recommended) 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on lower-cased text in the top 102 languages with the largest Wikipedias |
| | | (see `details <https://github.com/google-research/bert/blob/master/multilingual.md>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-multilingual-cased`` | | (New, **recommended**) 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased text in the top 104 languages with the largest Wikipedias |
| | | (see `details <https://github.com/google-research/bert/blob/master/multilingual.md>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-chinese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased Chinese Simplified and Traditional text. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-german-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | Trained on cased German text by Deepset.ai |
| | | (see `details on deepset.ai website <https://deepset.ai/german-bert>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-large-uncased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
| | | | Trained on lower-cased English text using Whole-Word-Masking |
| | | (see `details <https://github.com/google-research/bert/#bert>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-large-cased-whole-word-masking`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
| | | | Trained on cased English text using Whole-Word-Masking |
| | | (see `details <https://github.com/google-research/bert/#bert>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-large-uncased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
| | | | The ``bert-large-uncased-whole-word-masking`` model fine-tuned on SQuAD |
| | | (see details of fine-tuning in the `example section <https://github.com/huggingface/pytorch-transformers/tree/master/examples>`__). |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-large-cased-whole-word-masking-finetuned-squad`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters |
| | | | The ``bert-large-cased-whole-word-masking`` model fine-tuned on SQuAD |
| | | (see `details of fine-tuning in the example section <https://huggingface.co/pytorch-transformers/examples.html>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``bert-base-cased-finetuned-mrpc`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | The ``bert-base-cased`` model fine-tuned on MRPC |
| | | (see `details of fine-tuning in the example section <https://huggingface.co/pytorch-transformers/examples.html>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| GPT | ``openai-gpt`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | OpenAI GPT English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| GPT-2 | ``gpt2`` | | 12-layer, 768-hidden, 12-heads, 117M parameters. |
| | | | OpenAI GPT-2 English model |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``gpt2-medium`` | | 24-layer, 1024-hidden, 16-heads, 345M parameters. |
| | | | OpenAI's Medium-sized GPT-2 English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| Transformer-XL | ``transfo-xl-wt103`` | | 18-layer, 1024-hidden, 16-heads, 257M parameters. |
| | | | English model trained on wikitext-103 |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| XLNet | ``xlnet-base-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
| | | | XLNet English model |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlnet-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 340M parameters. |
| | | | XLNet Large English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| XLM | ``xlm-mlm-en-2048`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM English model |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-mlm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM English-German Multi-language model |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-mlm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM English-French Multi-language model |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-mlm-enro-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM English-Romanian Multi-language model |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-mlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM Model pre-trained with MLM on the `15 XNLI languages <https://github.com/facebookresearch/XNLI>`__. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-mlm-tlm-xnli15-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM Model pre-trained with MLM + TLM on the `15 XNLI languages <https://github.com/facebookresearch/XNLI>`__. |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-clm-enfr-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM English model trained with CLM (Causal Language Modeling) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``xlm-clm-ende-1024`` | | 12-layer, 1024-hidden, 8-heads |
| | | | XLM English-German Multi-language model trained with CLM (Causal Language Modeling) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| RoBERTa | ``roberta-base`` | | 12-layer, 768-hidden, 12-heads, 125M parameters |
| | | | RoBERTa using the BERT-base architecture |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``roberta-large`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters |
| | | | RoBERTa using the BERT-large architecture |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| | ``roberta-large-mnli`` | | 24-layer, 1024-hidden, 16-heads, 355M parameters |
| | | | ``roberta-large`` fine-tuned on `MNLI <http://www.nyu.edu/projects/bowman/multinli/>`__. |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/roberta>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
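
Example (a minimal sketch of using the ``roberta-large-mnli`` checkpoint above; it assumes the hosted shortcut resolves to a 3-label configuration and that ``encode`` accepts a sentence pair, as described in the model docstrings)::

    import torch
    from pytorch_transformers import RobertaForSequenceClassification, RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
    model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
    model.eval()

    premise = "Roberta is a heavily optimized version of BERT."
    hypothesis = "Roberta is based on BERT."
    input_ids = torch.tensor([tokenizer.encode(premise, hypothesis, add_special_tokens=True)])
    with torch.no_grad():
        logits = model(input_ids)[0]  # shape (1, 3)
    # The label ordering is assumed to follow the fairseq MNLI head
    # (contradiction, neutral, entailment); see the label swap in run_glue.py below.
    print(logits.argmax(dim=1))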


@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet)."""
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
from __future__ import absolute_import, division, print_function
@@ -33,6 +33,9 @@ from tqdm import tqdm, trange
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
BertForSequenceClassification, BertTokenizer,
RobertaConfig,
RobertaForSequenceClassification,
RobertaTokenizer,
XLMConfig, XLMForSequenceClassification,
XLMTokenizer, XLNetConfig,
XLNetForSequenceClassification,
@@ -45,12 +48,13 @@ from utils_glue import (compute_metrics, convert_examples_to_features,
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
}
@@ -214,7 +218,7 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM and RoBERTa don't use segment_ids
'labels': batch[3]}
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
@@ -264,14 +268,20 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(args.model_type in ['roberta']), # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token],
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
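# --- Illustrative sketch (not part of run_glue.py) of the two RoBERTa-specific tweaks above. ---
# 1) `sep_token_extra` inserts a second separator between the two sentences of a pair;
#    the token strings used here are RoBERTa's cls/sep tokens ('<s>' / '</s>').
def build_pair(tokens_a, tokens_b, cls_token='<s>', sep_token='</s>', sep_token_extra=True):
    tokens = tokens_a + [sep_token]
    if sep_token_extra:
        tokens += [sep_token]            # BERT:    [CLS] A [SEP] B [SEP]
    tokens += tokens_b + [sep_token]     # RoBERTa: <s> A </s> </s> B </s>
    return [cls_token] + tokens
print(build_pair(['is', 'this', 'jacksonville', '?'], ['no', 'it', 'is', 'not', '.']))
# 2) The MNLI "HACK" swaps label indices 1 and 2 so the processor's label list lines up
#    with the ordering assumed to be used by the fairseq pretrained classification head.
label_list = ['contradiction', 'entailment', 'neutral']    # ordering returned by utils_glue's MnliProcessor
label_list[1], label_list[2] = label_list[2], label_list[1]
print(label_list)                                           # ['contradiction', 'neutral', 'entailment']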


@@ -390,10 +390,16 @@ class WnliProcessor(DataProcessor):
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer, output_mode,
cls_token_at_end=False,
cls_token='[CLS]',
cls_token_segment_id=1,
sep_token='[SEP]',
sep_token_extra=False,
pad_on_left=False,
pad_token=0,
pad_token_segment_id=0,
sequence_a_segment_id=0,
sequence_b_segment_id=1,
mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token:
@@ -442,6 +448,9 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = tokens_a + [sep_token]
if sep_token_extra:
# roberta uses an extra separator b/w pairs of sentences
tokens += [sep_token]
segment_ids = [sequence_a_segment_id] * len(tokens)
if tokens_b:


@@ -6,6 +6,8 @@ from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils import (PreTrainedTokenizer)
from .modeling_auto import (AutoConfig, AutoModel)
@@ -36,6 +38,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)


@@ -0,0 +1,181 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""
from __future__ import absolute_import, division, print_function
import argparse
import logging
import numpy as np
import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
BertIntermediate, BertLayer,
BertModel, BertOutput,
BertSelfAttention,
BertSelfOutput)
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaModel)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
"""
Copy/paste/tweak roberta's weights to our BERT structure.
"""
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout
config = BertConfig(
vocab_size_or_config_json_file=50265,
hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers,
num_attention_heads=roberta.args.encoder_attention_heads,
intermediate_size=roberta.args.encoder_ffn_embed_dim,
max_position_embeddings=514,
type_vocab_size=1,
)
if classification_head:
config.num_labels = roberta.args.num_classes
print("Our BERT config:", config)
model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
model.eval()
# Now let's copy all the weights.
# Embeddings
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
model.roberta.embeddings.LayerNorm.variance_epsilon = roberta_sent_encoder.emb_layer_norm.eps
for i in range(config.num_hidden_layers):
# Encoder: start of layer
layer: BertLayer = model.roberta.encoder.layer[i]
roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
### self attention
self_attn: BertSelfAttention = layer.attention.self
assert(
roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
)
# we use three distinct linear layers so we split the source layer here.
self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]
### self-attention output
self_output: BertSelfOutput = layer.attention.output
assert(
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
)
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
self_output.LayerNorm.variance_epsilon = roberta_layer.self_attn_layer_norm.eps
### intermediate
intermediate: BertIntermediate = layer.intermediate
assert(
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
)
intermediate.dense.weight = roberta_layer.fc1.weight
intermediate.dense.bias = roberta_layer.fc1.bias
### output
bert_output: BertOutput = layer.output
assert(
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
)
bert_output.dense.weight = roberta_layer.fc2.weight
bert_output.dense.bias = roberta_layer.fc2.bias
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
#### end of layer
if classification_head:
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
else:
# LM Head
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
model.lm_head.bias = roberta.model.decoder.lm_head.bias
# Let's check that we get the same results.
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0]
if classification_head:
their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids))
else:
their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape)
success = torch.allclose(our_output, their_output, atol=1e-3)
print(
"Do both models output the same tensors?",
"🔥" if success else "💩"
)
if not success:
raise Exception("Something went wRoNg")
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--roberta_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the official PyTorch dump.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--classification_head",
action = "store_true",
help = "Whether to convert a final classification head.")
args = parser.parse_args()
convert_roberta_checkpoint_to_pytorch(
args.roberta_checkpoint_path,
args.pytorch_dump_folder_path,
args.classification_head
)
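# --- Post-conversion sanity check (illustrative sketch, not part of this script). ---
# save_pretrained() writes config.json and pytorch_model.bin, so the converted checkpoint can be
# reloaded directly. './converted-roberta' stands in for the --pytorch_dump_folder_path used above;
# the tokenizer files come from the hosted 'roberta-large' shortcut (assumed to match the
# converted checkpoint's vocabulary).
import torch
from pytorch_transformers import RobertaForMaskedLM, RobertaTokenizer
model = RobertaForMaskedLM.from_pretrained('./converted-roberta')
tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
model.eval()
input_ids = torch.tensor([tokenizer.encode("Hello world!", add_special_tokens=True)])
with torch.no_grad():
    prediction_scores = model(input_ids)[0]   # (1, sequence_length, vocab_size)
print(prediction_scores.shape)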


@@ -0,0 +1,349 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model. """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
BertLayerNorm, BertModel,
BertPreTrainedModel, gelu)
from pytorch_transformers.modeling_utils import add_start_docstrings
logger = logging.getLogger(__name__)
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
}
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
}
class RobertaEmbeddings(BertEmbeddings):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config):
super(RobertaEmbeddings, self).__init__(config)
self.padding_idx = 1
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
# Position numbers begin at padding_idx+1. Padding symbols are ignored.
# cf. fairseq's `utils.make_positions`
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
return super(RobertaEmbeddings, self).forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
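# Worked illustration of the position numbering above (comment only, not model code):
# with padding_idx = 1, the first real token gets position id 2, matching fairseq's
# utils.make_positions. For a sequence of length 5:
#     torch.arange(1 + 1, 5 + 1 + 1)  ->  tensor([2, 3, 4, 5, 6])
# This offset is also why the conversion script sets max_position_embeddings to 514 for a
# model that accepts 512 tokens: position ids 0 and 1 are never used for real tokens.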
class RobertaConfig(BertConfig):
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
`RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer,
Veselin Stoyanov. It is based on Google's BERT model released in 2018.
It builds on BERT and modifies key hyperparameters, removing the next-sentence pretraining
objective and training with much larger mini-batches and learning rates.
This implementation is the same as BertModel with a tiny embeddings tweak as well as a setup for Roberta pretrained
models.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`RoBERTa: A Robustly Optimized BERT Pretraining Approach`:
https://arxiv.org/abs/1907.11692
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.RobertaConfig`): Model configuration class with all the parameters of the
model.
"""
ROBERTA_INPUTS_DOCSTRING = r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
To match pre-training, RoBERTa input sequences should be formatted with ``<s>`` and ``</s>`` tokens (RoBERTa's equivalents of BERT's [CLS] and [SEP]) as follows:
(a) For sequence pairs:
``tokens: <s> is this jack ##son ##ville ? </s></s> no it is not . </s>``
(b) For single sequences:
``tokens: <s> the dog is hairy . </s>``
Fully encoded sequences or sequence pairs can be obtained using the RobertaTokenizer.encode function with
the ``add_special_tokens`` parameter set to ``True``.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1[``.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare RoBERTa Model transformer outputing raw hidden-states without any specific head on top.",
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaModel(BertModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input; you're often better off averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaModel, self).__init__(config)
self.embeddings = RobertaEmbeddings(config)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
if input_ids[:, 0].sum().item() != 0:
logger.warning("A sequence with no special tokens has been passed to the RoBERTa model. "
"This model requires special tokens in order to work. "
"Please specify add_special_tokens=True in your encoding.")
return super(RobertaModel, self).forward(input_ids, token_type_ids, attention_mask, position_ids, head_mask)
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForMaskedLM(BertPreTrainedModel):
r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)
loss, prediction_scores = outputs[:2]
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaForMaskedLM, self).__init__(config)
self.roberta = RobertaModel(config)
self.lm_head = RobertaLMHead(config)
self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, position_ids=None,
head_mask=None):
outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
class RobertaLMHead(nn.Module):
"""Roberta Head for masked language modeling."""
def __init__(self, config):
super(RobertaLMHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x) + self.bias
return x
@add_start_docstrings("""RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer
on top of the pooled output) e.g. for GLUE tasks. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
class RobertaForSequenceClassification(BertPreTrainedModel):
r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:]
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super(RobertaClassificationHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
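# --- Shape check for the classification head above (illustrative sketch, not library code). ---
# RobertaClassificationHead pools the <s> token (features[:, 0, :]) instead of relying on
# BERT's separate pooler. The tiny config below is hypothetical and sized only for the demo.
if __name__ == "__main__":
    demo_config = RobertaConfig(vocab_size_or_config_json_file=50265, hidden_size=16,
                                num_hidden_layers=2, num_attention_heads=2, intermediate_size=32)
    demo_config.num_labels = 3
    head = RobertaClassificationHead(demo_config)
    features = torch.randn(4, 10, demo_config.hidden_size)  # (batch, seq_len, hidden) as returned by RobertaModel
    logits = head(features)
    print(logits.shape)  # torch.Size([4, 3])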


@@ -0,0 +1,242 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
import torch
from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification)
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
class RobertaModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (RobertaForMaskedLM, RobertaModel)
class RobertaModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = RobertaConfig(
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_roberta_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels,
token_labels, choice_labels):
model = RobertaModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
sequence_output, pooled_output = model(input_ids, token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_roberta_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels,
token_labels, choice_labels):
model = RobertaForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = RobertaModelTest.RobertaModelTester(self)
self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_roberta_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_transformers_test/"
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
class RobertaModelIntegrationTest(unittest.TestCase):
@pytest.mark.slow
def test_inference_masked_lm(self):
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 50265))
self.assertEqual(
output.shape,
expected_shape
)
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[33.8843, -4.3107, 22.7779],
[ 4.6533, -2.8099, 13.6252],
[ 1.8222, -3.6898, 8.8600]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
@pytest.mark.slow
def test_inference_no_head(self):
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[-0.0231, 0.0782, 0.0074],
[-0.1854, 0.0539, -0.0174],
[ 0.0548, 0.0799, 0.1687]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
@pytest.mark.slow
def test_inference_classification_head(self):
model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 3))
self.assertEqual(
output.shape,
expected_shape
)
expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
self.assertTrue(
torch.allclose(output, expected_tensor, atol=1e-3)
)
if __name__ == "__main__":
unittest.main()
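# Standalone sketch of the masked-LM integration check above (illustrative only; assumes
# the 'roberta-base' weights can be downloaded). Nothing is asserted here, it only shows
# how the logits compared against expected_slice are obtained.
import torch
from pytorch_transformers import RobertaForMaskedLM
example_model = RobertaForMaskedLM.from_pretrained('roberta-base')
example_model.eval()
example_input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
example_scores = example_model(example_input_ids)[0]  # shape (1, 11, 50265)
example_predictions = example_scores.argmax(dim=-1)  # most likely vocab id at each position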

View File

@ -125,6 +125,17 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertFalse(_is_punctuation(u"A"))
self.assertFalse(_is_punctuation(u" "))
def test_sequence_builders(self):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
assert encoded_sentence == [101] + text + [102]
assert encoded_pair == [101] + text + [102] + text_2 + [102]
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import json
import unittest
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = RobertaTokenizer
def setUp(self):
super(RobertaTokenizationTest, self).setUp()
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
"lo", "low", "er",
"low", "lowest", "newer", "wider", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
with open(self.vocab_file, "w") as fp:
fp.write(json.dumps(vocab_tokens))
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self):
return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
def get_input_output_texts(self):
input_text = u"lower newer"
output_text = u"lower<unk>newer"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [13, 12, 17]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def roberta_dict_integration_testing(self):
tokenizer = self.get_tokenizer()
self.assertListEqual(
tokenizer.encode('Hello world!'),
[0, 31414, 232, 328, 2]
)
self.assertListEqual(
tokenizer.encode('Hello world! cécé herlolip 418'),
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
)
def test_sequence_builders(self):
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
if __name__ == '__main__':
unittest.main()

View File

@ -66,6 +66,17 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def test_sequence_builders(self):
tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
assert encoded_sentence == [1] + text + [1]
assert encoded_pair == [1] + text + [1] + text_2 + [1]
if __name__ == '__main__':
unittest.main()

View File

@ -89,6 +89,18 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
def test_sequence_builders(self):
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")
encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
assert encoded_sentence == text + [4, 3]
assert encoded_pair == text + [4] + text_2 + [4, 3]
if __name__ == '__main__':
unittest.main()

View File

@ -166,6 +166,22 @@ class BertTokenizer(PreTrainedTokenizer):
out_string = ' '.join(tokens).replace(' ##', '').strip()
return out_string
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
A BERT sequence has the following format: [CLS] X [SEP]
"""
return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
"""
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
return cls + token_ids_0 + sep + token_ids_1 + sep
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0

View File

@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
from .tokenization_gpt2 import bytes_to_unicode, get_pairs
from .tokenization_utils import PreTrainedTokenizer
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator so the import checks pass on Python 2;
# the byte-level unicode BPE tokenizer itself is not supported on Python 2.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
},
'merges_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'roberta-base': 512,
'roberta-large': 512,
'roberta-large-mnli': 512,
}
class RobertaTokenizer(PreTrainedTokenizer):
"""
RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. Peculiarities: Byte-level BPE
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs):
super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
mask_token=mask_token, **kwargs)
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
@property
def vocab_size(self):
return len(self.encoder)
def bpe(self, token):
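# Greedily merges the lowest-ranked adjacent pair from self.bpe_ranks until no known
# pair remains, then returns the sub-tokens joined by spaces; results are memoized in self.cache.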
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:  # 'first' no longer occurs in the remaining word
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def _tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
text = ''.join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
A RoBERTa sequence has the following format: <s> X </s>
"""
return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
"""
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
return cls + token_ids_0 + sep + sep + token_ids_1 + sep
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
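# Minimal usage sketch (illustrative only; assumes the 'roberta-base' shortcut defined
# above can be downloaded). It shows the two layouts produced by the helpers in this
# class: <s> A </s> for a single sentence and <s> A </s></s> B </s> for a pair.
example_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
example_single = example_tokenizer.encode("sequence builders", add_special_tokens=True)
example_pair = example_tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)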

View File

@ -180,9 +180,10 @@ class PreTrainedTokenizer(object):
@classmethod
def from_pretrained(cls, *inputs, **kwargs):
r""" Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
r"""
Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
Parameters:
Args:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
@ -383,14 +384,15 @@ class PreTrainedTokenizer(object):
def add_tokens(self, new_tokens):
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary.
Parameters:
new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Args:
new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
Returns:
Number of tokens added to the vocabulary.
Examples::
@ -422,17 +424,20 @@ class PreTrainedTokenizer(object):
def add_special_tokens(self, special_tokens_dict):
""" Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If special tokens are NOT in the vocabulary, they are added
to it (indexed starting from the last index of the current vocabulary).
"""
Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If special tokens are NOT in the vocabulary, they are added
to it (indexed starting from the last index of the current vocabulary).
Parameters:
special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``].
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Args:
special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
[``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
``additional_special_tokens``].
Returns:
Number of tokens added to the vocabulary.
Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
Examples::
@ -519,14 +524,37 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
def encode(self, text, text_pair=None, add_special_tokens=False):
"""
return self.convert_tokens_to_ids(self.tokenize(text))
Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
Args:
text: The first sequence to be encoded.
text_pair: Optional second sequence to be encoded.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
"""
if text_pair is None:
if add_special_tokens:
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
else:
return self.convert_tokens_to_ids(self.tokenize(text))
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
else:
return first_sentence_tokens, second_sentence_tokens
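# Illustrative behaviour (names below are placeholders, nothing is asserted): for a pair of
# texts, tokenizer.encode(text, text_pair, add_special_tokens=True) returns a single list of
# ids with the model-specific special tokens inserted, whereas the default
# add_special_tokens=False returns a tuple of two plain id lists.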
def add_special_tokens_single_sentence(self, token_ids):
raise NotImplementedError
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
raise NotImplementedError
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
@ -561,16 +589,28 @@ class PreTrainedTokenizer(object):
return ' '.join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
"""
Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = self.clean_up_tokenization(text)
return text
if self.sep_token is not None and self.sep_token in text:
text = text.replace(self.cls_token, self.sep_token)
split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
if clean_up_tokenization_spaces:
clean_text = [self.clean_up_tokenization(text) for text in split_text]
return clean_text
else:
return split_text
else:
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
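# Note: when the decoded text contains sep_token (e.g. ids produced with
# add_special_tokens=True for a sentence pair), decode() now returns a list with one
# string per segment; otherwise it returns a single string as before.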
@property
def special_tokens_map(self):
@ -602,7 +642,7 @@ class PreTrainedTokenizer(object):
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
return all_ids
@staticmethod

View File

@ -214,6 +214,22 @@ class XLMTokenizer(PreTrainedTokenizer):
out_string = ''.join(tokens).replace('</w>', ' ').strip()
return out_string
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
An XLM sequence has the following format: [CLS] X [SEP]
"""
return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
"""
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
return cls + token_ids_0 + sep + token_ids_1 + sep
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):

View File

@ -177,6 +177,24 @@ class XLNetTokenizer(PreTrainedTokenizer):
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
return out_string
def add_special_tokens_single_sentence(self, token_ids):
"""
Adds special tokens to a sequence for sequence classification tasks.
An XLNet sequence has the following format: X [SEP][CLS]
"""
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
return token_ids + sep + cls
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
"""
Adds special tokens to a sequence pair for sequence classification tasks.
An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
"""
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
return token_ids_0 + sep + token_ids_1 + sep + cls
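# e.g., with the ids used in the 'xlnet-base-cased' test above (sep=4, cls=3):
#     add_special_tokens_single_sentence(ids)          -> ids + [4, 3]
#     add_special_tokens_sentences_pair(ids_a, ids_b)  -> ids_a + [4] + ids_b + [4, 3]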
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.