unified tokenizer api and serialization + tests

thomwolf 2019-07-09 10:25:18 +02:00
parent 3d5f291386
commit b19786985d
34 changed files with 824 additions and 755 deletions

View File

@@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 
-from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
-from pytorch_transformers.modeling_bert import BertForSequenceClassification
-from pytorch_transformers.tokenization_bert import BertTokenizer
+from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification,
+                                  XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                  XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
+from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
+                                  XLMTokenizer)
 from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
 
 from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics
@@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, c
 logger = logging.getLogger(__name__)
 
+ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                            XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
+
+MODEL_CLASSES = {
+    'bert': BertForSequenceClassification,
+    'xlnet': XLNetForSequenceClassification,
+    'xlm': XLMForSequenceClassification,
+}
+
+TOKENIZER_CLASSES = {
+    'bert': BertTokenizer,
+    'xlnet': XLNetTokenizer,
+    'xlm': XLMTokenizer,
+}
+
 def train(args, train_features, model):
     """ Train the model """
@@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
     # Eval!
     logger.info("***** Running evaluation *****")
-    logger.info(" Num examples = %d", len(eval_examples))
+    logger.info(" Num examples = %d", len(eval_features))
     logger.info(" Batch size = %d", args.eval_batch_size)
     model.eval()
     eval_loss = 0
@@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
         examples = processor.get_dev_examples(args.data_dir)
     cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
         'dev' if eval else 'train',
-        list(filter(None, args.bert_model.split('/'))).pop(),
+        list(filter(None, args.model_name.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
@@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False):
         features = torch.load(cached_features_file)
     else:
-        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
+        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
+                                                cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
+                                                cls_token=tokenizer.cls_token,
+                                                sep_token=tokenizer.sep_token, cls_token_segment_id=2,
+                                                pad_on_left=True, pad_token_segment_id=4)
         if args.local_rank == -1 or torch.distributed.get_rank() == 0:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
@@ -230,12 +252,10 @@ def main():
     ## Required parameters
     parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--bert_model", default=None, type=str, required=True,
-                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
-                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
-                             "bert-base-multilingual-cased, bert-base-chinese.")
+    parser.add_argument("--model_name", default=None, type=str, required=True,
+                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--task_name", default=None, type=str, required=True,
-                        help="The name of the task to train.")
+                        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
@@ -243,9 +263,8 @@ def main():
     parser.add_argument("--cache_dir", default="", type=str,
                         help="Where do you want to store the pre-trained models downloaded from s3")
     parser.add_argument("--max_seq_length", default=128, type=int,
-                        help="The maximum total input sequence length after WordPiece tokenization. \n"
-                             "Sequences longer than this will be truncated, and sequences shorter \n"
-                             "than this will be padded.")
+                        help="The maximum total input sequence length after tokenization. Sequences longer "
+                             "than this will be truncated, sequences shorter will be padded.")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
     parser.add_argument("--do_eval", action='store_true',
@@ -263,8 +282,7 @@ def main():
     parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
     parser.add_argument("--warmup_proportion", default=0.1, type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. "
-                             "E.g., 0.1 = 10%% of training.")
+                        help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
     parser.add_argument("--no_cuda", action='store_true',
                         help="Avoid using CUDA when available")
     parser.add_argument('--overwrite_output_dir', action='store_true',
@@ -331,8 +349,11 @@ def main():
         # Make sure only the first process in distributed training will download model & vocab
         torch.distributed.barrier()
 
-    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
+    args.model_type = args.model_name.lower().split('-')[0]
+    args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
+    args.model_class = MODEL_CLASSES[args.model_type]
+    tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
+    model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)
 
     if args.local_rank == 0:
         torch.distributed.barrier()
@@ -359,27 +380,16 @@ def main():
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Save a trained model, configuration and tokenizer
-        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        # If we save using the predefined names, we can load using `from_pretrained`
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        model_to_save.config.to_json_file(output_config_file)
+        model.save_pretrained(args.output_dir)
         tokenizer.save_vocabulary(args.output_dir)
 
-        # Load a trained model and vocabulary that you have fine-tuned
-        model = BertForSequenceClassification.from_pretrained(args.output_dir)
-        tokenizer = BertTokenizer.from_pretrained(args.output_dir)
-
         # Good practice: save your training arguments together with the trained model
-        output_args_file = os.path.join(args.output_dir, 'training_args.bin')
-        torch.save(args, output_args_file)
-    else:
-        model = BertForSequenceClassification.from_pretrained(args.bert_model)
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
 
-    model.to(args.device)
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = args.model_class.from_pretrained(args.output_dir)
+        tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
+        model.to(args.device)
 
     # Evaluation
     if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
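A hedged usage sketch of the unified save/reload flow that this change introduces into the example script (not part of the diff; 'bert-base-uncased' and './output' are only illustrative values):

    # Minimal sketch, assuming a BERT checkpoint; any entry of MODEL_CLASSES works the same way.
    from pytorch_transformers import BertForSequenceClassification, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

    # ... fine-tune the model ...

    model.save_pretrained('./output')      # writes weights and config under the predefined names
    tokenizer.save_vocabulary('./output')  # writes the vocabulary files next to them

    # Both objects can then be restored from the same directory.
    model = BertForSequenceClassification.from_pretrained('./output')
    tokenizer = BertTokenizer.from_pretrained('./output')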

View File

@@ -211,8 +211,8 @@ def main():
             logger.info("No cache file at %s, preparing train features", cached_train_features_file)
             train_features = convert_examples_to_features(
                 train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
-                cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-                sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+                cls_token_at_end=True, cls_token=tokenizer.cls_token,
+                sep_token=tokenizer.sep_token, cls_token_segment_id=2,
                 pad_on_left=True, pad_token_segment_id=4)
             if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                 logger.info(" Saving train features into cached file %s", cached_train_features_file)
@@ -369,8 +369,8 @@ def main():
             logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
             eval_features = convert_examples_to_features(
                 eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
-                cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
-                sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
+                cls_token_at_end=True, cls_token=tokenizer.cls_token,
+                sep_token=tokenizer.sep_token, cls_token_segment_id=2,
                 pad_on_left=True, pad_token_segment_id=4)
             if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                 logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
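For context, a short sketch of the renamed special-token attributes used above (lowercase cls_token/sep_token instead of the old CLS_TOKEN/SEP_TOKEN); 'bert-base-uncased' is just an example checkpoint, not part of the diff:

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    print(tokenizer.cls_token, tokenizer.sep_token, tokenizer.unk_token)  # e.g. [CLS] [SEP] [UNK]
    cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)         # single token -> int id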

View File

@@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
                                  mask_padding_with_zero=True):
     """ Loads a data file into a list of `InputBatch`s
         `cls_token_at_end` define the location of the CLS token:
-            - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
+            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
             - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
         `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
     """
@@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
                 [str(x) for x in tokens]))
             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
-            logger.info(
-                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
             logger.info("label: %s (id = %d)" % (example.label, label_id))
 
         features.append(

View File

@@ -11,22 +11,28 @@ from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
                             BertForSequenceClassification, BertForMultipleChoice,
                             BertForTokenClassification, BertForQuestionAnswering,
-                            load_tf_weights_in_bert)
+                            load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+                            BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
-                              load_tf_weights_in_openai_gpt)
+                              load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                              OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
-                                  load_tf_weights_in_transfo_xl)
+                                  load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                  TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_gpt2 import (GPT2Config, GPT2Model,
                             GPT2LMHeadModel, GPT2DoubleHeadsModel,
-                            load_tf_weights_in_gpt2)
+                            load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                            GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_xlnet import (XLNetConfig,
                              XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
                              XLNetForSequenceClassification, XLNetForQuestionAnswering,
-                             load_tf_weights_in_xlnet)
+                             load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                             XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_xlm import (XLMConfig, XLMModel,
                            XLMWithLMHeadModel, XLMForSequenceClassification,
-                           XLMForQuestionAnswering)
+                           XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                           XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
                              PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
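A sketch of how a downstream script can enumerate every released checkpoint now that the archive maps are exported under per-model names (this mirrors the ALL_MODELS construction in run_glue.py above):

    from pytorch_transformers import (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                      XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
                                      XLM_PRETRAINED_MODEL_ARCHIVE_MAP)

    ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                                XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
                                                XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
    print(", ".join(ALL_MODELS))  # all shortcut names, e.g. bert-base-uncased, ..., xlm-mlm-en-2048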

View File

@@ -29,8 +29,7 @@ from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME,
                                                        TransfoXLConfig,
                                                        TransfoXLLMHeadModel,
                                                        load_tf_weights_in_transfo_xl)
-from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME,
-                                                           VOCAB_NAME)
+from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
 
 if sys.version_info[0] == 2:
     import cPickle as pickle
@@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
         with open(transfo_xl_dataset_file, "rb") as fp:
             corpus = pickle.load(fp, encoding="latin1")
         # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
-        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
         print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
         corpus_vocab_dict = corpus.vocab.__dict__
         torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

View File

@@ -24,7 +24,7 @@ import torch
 import numpy
 
 from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel)
-from pytorch_transformers.tokenization_xlm import MERGES_NAME, VOCAB_NAME
+from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
 
 
 def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
@@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
     # Save pytorch-model
     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
     pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+    pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
 
     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
     torch.save(model, pytorch_weights_dump_path)

View File

@@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrai
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
@@ -49,7 +49,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
     'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
 }
 
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
     'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
     'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
@@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 class BertConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `BertModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=30522,
@@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = BertConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"

View File

@@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
-                                "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
-                                 "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
+GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
+                                     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
+GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
+                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
 
 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
@@ -103,7 +103,7 @@ def gelu(x):
 class GPT2Config(PretrainedConfig):
     """Configuration class to store the configuration of a `GPT2Model`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(
         self,
@@ -358,7 +358,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = GPT2Config
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"

View File

@@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
-PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
+OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"}
+OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"}
 
 def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
@@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
 class OpenAIGPTConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `OpenAIGPTModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(
         self,
@@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = OpenAIGPTConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"

View File

@@ -41,10 +41,10 @@ from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrai
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
 }
@@ -179,7 +179,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
 class TransfoXLConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `TransfoXLModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=267735,
@@ -838,7 +838,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = TransfoXLConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"

View File

@@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
         model_to_prune = getattr(self, self.base_model_prefix, self)  # get the base model if needed
         model_to_prune._prune_heads(heads_to_prune)
 
+    def save_pretrained(self, save_directory):
+        """ Save a model with its configuration file to a directory, so that it
+            can be re-loaded using the `from_pretrained(save_directory)` class method.
+        """
+        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
+
+        # Only save the model it-self if we are using distributed training
+        model_to_save = self.module if hasattr(self, 'module') else self
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+        output_config_file = os.path.join(save_directory, CONFIG_NAME)
+
+        torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """

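The new save_pretrained bundles the steps the example scripts previously did by hand (unwrap a possible distributed/DataParallel wrapper, dump the state dict under WEIGHTS_NAME, write the config under CONFIG_NAME). A hedged usage sketch; './my_finetuned_model' is a made-up path and BertModel is just one PreTrainedModel subclass:

    import os
    from pytorch_transformers import BertModel

    model = BertModel.from_pretrained('bert-base-uncased')  # any PreTrainedModel subclass

    output_dir = './my_finetuned_model'
    os.makedirs(output_dir, exist_ok=True)  # save_pretrained asserts the path is an existing directory
    model.save_pretrained(output_dir)       # state dict under WEIGHTS_NAME, config under CONFIG_NAME

    reloaded = BertModel.from_pretrained(output_dir)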
View File

@@ -40,10 +40,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
 }
@@ -51,7 +51,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
 class XLMConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLMModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=30145,
@@ -357,7 +357,7 @@ class XLMPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = XLMConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = None
     base_model_prefix = "transformer"

View File

@@ -38,10 +38,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_MODEL_ARCHIVE_MAP = {
+XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
 }
-PRETRAINED_CONFIG_ARCHIVE_MAP = {
+XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
 }
@@ -195,7 +195,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
 class XLNetConfig(PretrainedConfig):
     """Configuration class to store the configuration of a `XLNetModel`.
     """
-    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
+    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
                  vocab_size_or_config_json_file=32000,
@@ -593,7 +593,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
         a simple interface for dowloading and loading pretrained models.
     """
     config_class = XLNetConfig
-    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+    pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"

View File

@@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM,
                                   BertForNextSentencePrediction, BertForPreTraining,
                                   BertForQuestionAnswering, BertForSequenceClassification,
                                   BertForTokenClassification, BertForMultipleChoice)
-from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 
@@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

View File

@@ -413,7 +413,7 @@ class GPTModelTester(object):
 
     def create_and_check_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
             model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.parent.assertIsNotNone(model)

View File

@@ -26,7 +26,7 @@ import pytest
 import torch
 
 from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
-from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 
@@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

View File

@@ -20,12 +20,12 @@ import unittest
 import logging
 
 from pytorch_transformers import PretrainedConfig, PreTrainedModel
-from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 
 
 class ModelUtilsTest(unittest.TestCase):
     def test_model_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             config = BertConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, PretrainedConfig)

View File

@@ -21,7 +21,7 @@ import shutil
 import pytest
 
 from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
-from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
 
@@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

View File

@@ -26,7 +26,7 @@ import pytest
 import torch
 
 from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
-from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
 
 from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
 
@@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/pytorch_transformers_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+        for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)

View File

@@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 from io import open
-import shutil
-import pytest
 
 from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                      BertTokenizer,
                                                      WordpieceTokenizer,
                                                      _is_control, _is_punctuation,
-                                                     _is_whitespace)
+                                                     _is_whitespace, VOCAB_FILES_NAMES)
 
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase):
     def test_full_tokenizer(self):
         vocab_tokens = [
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-            "##ing", ","
+            "##ing", ",", "low", "lowest",
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
+        vocab_directory = "/tmp/"
+        vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
+        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
             vocab_file = vocab_writer.name
 
-        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
+        create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
 
         tokenizer = BertTokenizer(vocab_file)
@@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
         vocab = {}
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i
-        tokenizer = WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
 
         self.assertListEqual(tokenizer.tokenize(""), [])

View File

@@ -17,8 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
+import tempfile
 
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
+from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -28,31 +29,31 @@ class GPT2TokenizationTest(unittest.TestCase):
         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "lo", "low", "er",
-                 "low", "lowest", "newer", "wider"]
+                 "low", "lowest", "newer", "wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-
-        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        text = "lower"
-        bpe_tokens = ["low", "er"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [13, 12, 16]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-        os.remove(vocab_file)
-        os.remove(merges_file)
+        special_tokens_map = {"unk_token": "<unk>"}
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+
+            tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
+            text = "lower"
+            bpe_tokens = ["low", "er"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + [tokenizer.unk_token]
+            input_bpe_tokens = [13, 12, 17]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
 if __name__ == '__main__':

View File

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import shutil
-import pytest
+import tempfile
 
-from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
+from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
 
 from.tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -32,31 +31,31 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "w</w>", "r</w>", "t</w>",
                  "lo", "low", "er</w>",
-                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-
-        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [14, 15, 20]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
+
+            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
+
+            text = "lower"
+            bpe_tokens = ["low", "er</w>"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + ["<unk>"]
+            input_bpe_tokens = [14, 15, 20]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
 if __name__ == '__main__':

View File

@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import sys
 from io import open
+import tempfile
 
 if sys.version_info[0] == 3:
     unicode = str
@@ -28,22 +29,19 @@ else:
 
 def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class(*inputs, **kwargs)
+    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
 
     before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
 
-    vocab_path = "/tmp/"
-    output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
-    tokenizer = tokenizer.from_pretrained(vocab_path)
-    for f in output_files:
-        os.remove(f)
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        tokenizer.save_pretrained(tmpdirname)
+        tokenizer = tokenizer.from_pretrained(tmpdirname)
 
     after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
     tester.assertListEqual(before_tokens, after_tokens)
 
 def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class(*inputs, **kwargs)
+    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
 
     text = u"Munich and Berlin are nice cities"
     filename = u"/tmp/tokenizer.bin"
@@ -58,8 +56,54 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs
     tester.assertListEqual(subwords, subwords_loaded)
 
+def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
+    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+
+    vocab_size = tokenizer.vocab_size
+    all_size = len(tokenizer)
+
+    tester.assertNotEqual(vocab_size, 0)
+    tester.assertEqual(vocab_size, all_size)
+
+    new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
+    added_toks = tokenizer.add_tokens(new_toks)
+    vocab_size_2 = tokenizer.vocab_size
+    all_size_2 = len(tokenizer)
+
+    tester.assertNotEqual(vocab_size_2, 0)
+    tester.assertEqual(vocab_size, vocab_size_2)
+    tester.assertEqual(added_toks, len(new_toks))
+    tester.assertEqual(all_size_2, all_size + len(new_toks))
+
+    tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
+    tester.assertGreaterEqual(len(tokens), 4)
+    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+
+    new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
+                  'pad_token': "<<<<<|||>|>>>>|>"}
+    added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
+    vocab_size_3 = tokenizer.vocab_size
+    all_size_3 = len(tokenizer)
+
+    tester.assertNotEqual(vocab_size_3, 0)
+    tester.assertEqual(vocab_size, vocab_size_3)
+    tester.assertEqual(added_toks_2, len(new_toks_2))
+    tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
+
+    tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
+    tester.assertGreaterEqual(len(tokens), 6)
+    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+    tester.assertGreater(tokens[0], tokens[1])
+    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+    tester.assertGreater(tokens[-2], tokens[-3])
+    tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
+    tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
+
 def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class(*inputs, **kwargs)
+    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
 
     text = u"He is very happy, UNwant\u00E9d,running"
     tokens = tokenizer.tokenize(text)
@@ -75,5 +119,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs
 def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
     create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
+    create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
     create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
     create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
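The new common test above exercises the token-addition half of the unified API. Roughly, it checks behaviour like the following sketch (GPT-2 and '<PAD>' are arbitrary example choices, not part of the diff):

    from pytorch_transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    num_added = tokenizer.add_tokens(['aaaaabbbbbb', 'cccccccccdddddddd'])  # returns the number added
    num_special = tokenizer.add_special_tokens({'pad_token': '<PAD>'})

    # vocab_size stays at the pretrained size, while len(tokenizer) also counts the added tokens,
    # and added tokens get ids above the original vocabulary range.
    assert tokenizer.vocab_size + num_added + num_special == len(tokenizer)
    assert tokenizer.convert_tokens_to_ids(tokenizer.pad_token) >= tokenizer.vocab_size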

View File

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 from io import open
-import shutil
-import pytest
+import tempfile
 
-from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
+from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
 
 from.tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -28,22 +27,23 @@ class TransfoXLTokenizationTest(unittest.TestCase):
 
     def test_full_tokenizer(self):
         vocab_tokens = [
-            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ","
+            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
+            "running", ",", "low", "l",
         ]
-        with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-            vocab_file = vocab_writer.name
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-        create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
+            create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
 
-        tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
-        os.remove(vocab_file)
+            tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
 
-        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
-        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+            tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+            self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
 
     def test_full_tokenizer_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=True)

View File

@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 import unittest
+import six
 
 from pytorch_transformers import PreTrainedTokenizer
 from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
@@ -27,8 +28,17 @@ class TokenizerUtilsTest(unittest.TestCase):
         for model_name in s3_models[:1]:
             tokenizer = tokenizer_class.from_pretrained(model_name)
             self.assertIsNotNone(tokenizer)
+            self.assertIsInstance(tokenizer, tokenizer_class)
             self.assertIsInstance(tokenizer, PreTrainedTokenizer)
 
+            for special_tok in tokenizer.all_special_tokens:
+                if six.PY2:
+                    self.assertIsInstance(special_tok, unicode)
+                else:
+                    self.assertIsInstance(special_tok, str)
+                special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
+                self.assertIsInstance(special_tok_id, int)
+
     def test_pretrained_tokenizers(self):
         self.check_tokenizer_from_pretrained(GPT2Tokenizer)

View File

@@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import shutil
-import pytest
+import tempfile
 
-from pytorch_transformers.tokenization_xlm import XLMTokenizer
+from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import create_and_check_tokenizer_commons
 
@@ -31,31 +30,31 @@ class XLMTokenizationTest(unittest.TestCase):
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "w</w>", "r</w>", "t</w>",
                  "lo", "low", "er</w>",
-                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-
-        tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [14, 15, 20]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
+
+            tokenizer = XLMTokenizer(vocab_file, merges_file)
+
+            text = "lower"
+            bpe_tokens = ["low", "er</w>"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + ["<unk>"]
+            input_bpe_tokens = [14, 15, 20]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
 if __name__ == '__main__':

View File

@ -16,10 +16,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os import os
import unittest import unittest
import shutil import tempfile
import pytest
from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
from .tokenization_tests_commons import create_and_check_tokenizer_commons from .tokenization_tests_commons import create_and_check_tokenizer_commons
@ -29,34 +28,37 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
class XLNetTokenizationTest(unittest.TestCase): class XLNetTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self): def test_full_tokenizer(self):
create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize(u'This is a test') with tempfile.TemporaryDirectory() as tmpdirname:
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) tokenizer.save_pretrained(tmpdirname)
self.assertListEqual( create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") tokens = tokenizer.tokenize(u'This is a test')
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0,
602, 347, 347, 347, 3, 12, 66,
46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids) self.assertListEqual(
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
u'or', u'n', SPIECE_UNDERLINE + u'in',
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',', tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
u'<unk>', u'.']) u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids, [8, 21, 84, 55, 24, 19, 7, 0,
602, 347, 347, 347, 3, 12, 66,
46, 72, 80, 6, 0, 4])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in',
SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
def test_tokenizer_lower(self): def test_tokenizer_lower(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
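# A minimal sketch of the save/reload round trip the updated test relies on:
# save_pretrained() writes the vocabulary plus special_tokens_map.json and
# added_tokens.json, and from_pretrained() can then point at that directory.
# 'spiece.model' is a hypothetical local SentencePiece file; Python 3 assumed.
import tempfile
from pytorch_transformers.tokenization_xlnet import XLNetTokenizer

tokenizer = XLNetTokenizer('spiece.model', keep_accents=True)
with tempfile.TemporaryDirectory() as tmpdirname:
    tokenizer.save_pretrained(tmpdirname)
    reloaded = XLNetTokenizer.from_pretrained(tmpdirname)
    assert reloaded.convert_tokens_to_ids(u'▁This') == tokenizer.convert_tokens_to_ids(u'▁This')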

View File

@ -22,7 +22,6 @@ import os
import unicodedata import unicodedata
from io import open from io import open
from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,20 +31,21 @@ VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': 'vocab_file':
{ {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
}} }
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-uncased': 512, 'bert-base-uncased': 512,
@ -93,8 +93,9 @@ class BertTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
mask_token="[MASK]", **kwargs):
"""Constructs a BertTokenizer. """Constructs a BertTokenizer.
Args: Args:
@ -102,17 +103,18 @@ class BertTokenizer(PreTrainedTokenizer):
do_lower_case: Whether to lower case the input do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece. do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
never_split: List of tokens which will never be split during tokenization. never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False Only has an effect when do_wordpiece_only=False
""" """
super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
if never_split is None:
never_split = self.all_special_tokens
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict( self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()]) [(ids, tok) for tok, ids in self.vocab.items()])
@ -120,90 +122,34 @@ class BertTokenizer(PreTrainedTokenizer):
if do_basic_tokenize: if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split) never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
self.max_len = max_len if max_len is not None else int(1e12)
@property @property
def UNK_TOKEN(self): def vocab_size(self):
return "[UNK]" return len(self.vocab)
@property def _tokenize(self, text):
def SEP_TOKEN(self):
return "[SEP]"
@property
def PAD_TOKEN(self):
return "[PAD]"
@property
def CLS_TOKEN(self):
return "[CLS]"
@property
def MASK_TOKEN(self):
return "[MASK]"
@property
def UNK_ID(self):
return self.vocab["[UNK]"]
@property
def SEP_ID(self):
return self.vocab["[SEP]"]
@property
def PAD_ID(self):
return self.vocab["[PAD]"]
@property
def CLS_ID(self):
return self.vocab["[CLS]"]
@property
def MASK_ID(self):
return self.vocab["[MASK]"]
def tokenize(self, text):
split_tokens = [] split_tokens = []
if self.do_basic_tokenize: if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text): for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
for sub_token in self.wordpiece_tokenizer.tokenize(token): for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token) split_tokens.append(sub_token)
else: else:
split_tokens = self.wordpiece_tokenizer.tokenize(text) split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens return split_tokens
def convert_tokens_to_ids(self, tokens): def _convert_token_to_id(self, token):
"""Converts a sequence of tokens into ids using the vocab.""" """ Converts a token (str/unicode) in an id using the vocab. """
ids = [] return self.vocab.get(token, self.vocab.get(self.unk_token))
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids): def _convert_id_to_token(self, index):
"""Converts a sequence of ids in wordpiece tokens using the vocab.""" """Converts an index (integer) in a token (string/unicode) using the vocab."""
tokens = [] return self.ids_to_tokens.get(index, self.unk_token)
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
def encode(self, text): def _convert_ids_to_string(self, tokens_ids):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, token_ids, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string.""" """Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(token_ids) tokens = self.convert_ids_to_tokens(tokens_ids)
out_string = ''.join(tokens).replace(' ##', '').strip() out_string = ''.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN):
out_string = out_string.replace(special_tok, '')
out_string = clean_up_tokenization(out_string)
return out_string return out_string
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
@ -245,17 +191,20 @@ class BasicTokenizer(object):
def __init__(self, def __init__(self,
do_lower_case=True, do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): never_split=None):
"""Constructs a BasicTokenizer. """Constructs a BasicTokenizer.
Args: Args:
do_lower_case: Whether to lower case the input. do_lower_case: Whether to lower case the input.
""" """
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case self.do_lower_case = do_lower_case
self.never_split = never_split self.never_split = never_split
def tokenize(self, text): def tokenize(self, text, never_split=None):
"""Tokenizes a piece of text.""" """Tokenizes a piece of text."""
never_split = self.never_split + (never_split if never_split is not None else [])
text = self._clean_text(text) text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese # This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't # models. This is also applied to the English models now, but it doesn't
@ -267,7 +216,7 @@ class BasicTokenizer(object):
orig_tokens = whitespace_tokenize(text) orig_tokens = whitespace_tokenize(text)
split_tokens = [] split_tokens = []
for token in orig_tokens: for token in orig_tokens:
if self.do_lower_case and token not in self.never_split: if self.do_lower_case and token not in never_split:
token = token.lower() token = token.lower()
token = self._run_strip_accents(token) token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token)) split_tokens.extend(self._run_split_on_punc(token))
@ -286,9 +235,9 @@ class BasicTokenizer(object):
output.append(char) output.append(char)
return "".join(output) return "".join(output)
def _run_split_on_punc(self, text): def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text.""" """Splits punctuation on a piece of text."""
if text in self.never_split: if never_split is not None and text in never_split:
return [text] return [text]
chars = list(text) chars = list(text)
i = 0 i = 0
@ -360,7 +309,7 @@ class BasicTokenizer(object):
class WordpieceTokenizer(object): class WordpieceTokenizer(object):
"""Runs WordPiece tokenization.""" """Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab self.vocab = vocab
self.unk_token = unk_token self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word self.max_input_chars_per_word = max_input_chars_per_word
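# Usage sketch for the refactored BertTokenizer: the public entry points
# (tokenize/encode/convert_*) now live in PreTrainedTokenizer and dispatch to the
# private hooks defined above, while special tokens are exposed as attributes.
# Assumes network access to fetch the 'bert-base-uncased' vocabulary.
from pytorch_transformers.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer.tokenize(u"UNwanted , running")    # basic + wordpiece tokenization
ids = tokenizer.encode(u"UNwanted , running")          # == convert_tokens_to_ids(tokenize(...))
assert ids == tokenizer.convert_tokens_to_ids(tokens)
print(tokenizer.unk_token, tokenizer.cls_token, tokenizer.sep_token, tokenizer.mask_token)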

View File

@ -38,7 +38,6 @@ logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', 'vocab_file': 'vocab.json',
'merges_file': 'merges.txt', 'merges_file': 'merges.txt',
'special_tokens_file': 'special_tokens.txt'
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
@ -52,11 +51,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
}, },
'special_tokens_file':
{
'gpt2': None,
'gpt2-medium': None,
}
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
@ -108,8 +102,10 @@ class GPT2Tokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None): def __init__(self, vocab_file, merges_file, errors='replace',
self.max_len = max_len if max_len is not None else int(1e12) bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
@ -123,32 +119,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
all_special_tokens = [] @property
if special_tokens_file is not None: def vocab_size(self):
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] return len(self.encoder)
all_special_tokens.extend(special_tokens_to_add)
if special_tokens is not None and special_tokens:
all_special_tokens.extend(special_tokens)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(all_special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token): def bpe(self, token):
if token in self.cache: if token in self.cache:
@ -191,7 +164,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.cache[token] = word self.cache[token] = word
return word return word
def tokenize(self, text): def _tokenize(self, text):
""" Tokenize a string. """ """ Tokenize a string. """
bpe_tokens = [] bpe_tokens = []
for token in re.findall(self.pat, text): for token in re.findall(self.pat, text):
@ -202,57 +175,27 @@ class GPT2Tokenizer(PreTrainedTokenizer):
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens return bpe_tokens
def convert_tokens_to_ids(self, tokens): def _convert_token_to_id(self, token):
""" Converts a sequence of tokens into ids using the vocab. """ """ Converts a token (str/unicode) in an id using the vocab. """
ids = [] return self.encoder.get(token, self.encoder.get(self.unk_token))
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False): def _convert_id_to_token(self, index):
"""Converts a sequence of ids in BPE tokens using the vocab.""" """Converts an index (integer) in a token (string/unicode) using the vocab."""
tokens = [] return self.decoder.get(index, self.unk_token)
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text): def _convert_ids_to_string(self, tokens_ids):
return self.convert_tokens_to_ids(self.tokenize(text)) """Converts a sequence of ids in a string."""
text = ''.join(tokens_ids)
def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
if clean_up_tokenization_spaces:
text = text.replace('<unk>', '')
text = clean_up_tokenization(text)
return text return text
def save_vocabulary(self, vocab_path): def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory.""" """Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
@ -268,14 +211,4 @@ class GPT2Tokenizer(PreTrainedTokenizer):
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(' '.join(bpe_tokens) + u'\n')
index += 1 index += 1
index = len(self.encoder) return vocab_file, merge_file
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
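# Sketch of the GPT-2 tokenizer after the refactor: its special tokens
# ('<|endoftext|>' for both bos and eos) are handled by the base class, and the
# byte-level BPE round-trips plain text through encode()/decode().
# Assumes network access to fetch the 'gpt2' vocabulary and merges files.
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode(u"Hello world")
assert tokenizer.decode(ids) == u"Hello world"
print(tokenizer.bos_token, tokenizer.eos_token)   # '<|endoftext|>' '<|endoftext|>'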

View File

@ -20,13 +20,9 @@ import json
import logging import logging
import os import os
import re import re
import sys
from io import open from io import open
from tqdm import tqdm from .tokenization_utils import PreTrainedTokenizer
from .file_utils import cached_path
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
from .tokenization_bert import BasicTokenizer from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,7 +30,6 @@ logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', 'vocab_file': 'vocab.json',
'merges_file': 'merges.txt', 'merges_file': 'merges.txt',
'special_tokens_file': 'special_tokens.txt'
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
@ -46,10 +41,6 @@ PRETRAINED_VOCAB_FILES_MAP = {
{ {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
}, },
'special_tokens_file':
{
'openai-gpt': None,
}
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
@ -88,14 +79,14 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
BPE tokenizer. Peculiarities: BPE tokenizer. Peculiarities:
- lower case all inputs - lower case all inputs
- uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not. - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
""" """
vocab_files_names = VOCAB_FILES_NAMES vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None): def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
try: try:
import ftfy import ftfy
import spacy import spacy
@ -103,11 +94,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
self.fix_text = ftfy.fix_text self.fix_text = ftfy.fix_text
except ImportError: except ImportError:
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True, self.nlp = BasicTokenizer(do_lower_case=True)
never_split=special_tokens if special_tokens is not None else [])
self.fix_text = None self.fix_text = None
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
@ -115,35 +104,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {} self.cache = {}
all_special_tokens = [] @property
if special_tokens_file is not None: def vocab_size(self):
special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] return len(self.encoder)
all_special_tokens.extend(special_tokens_to_add)
if special_tokens is not None and special_tokens:
all_special_tokens.extend(special_tokens)
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(all_special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
if self.fix_text is None:
# Using BERT's BasicTokenizer: we can update the tokenizer
self.nlp.never_split = special_tokens
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token): def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',) word = tuple(token[:-1]) + (token[-1] + '</w>',)
@ -188,7 +151,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
self.cache[token] = word self.cache[token] = word
return word return word
def tokenize(self, text): def _tokenize(self, text):
""" Tokenize a string. """ """ Tokenize a string. """
split_tokens = [] split_tokens = []
if self.fix_text is None: if self.fix_text is None:
@ -203,58 +166,26 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
return split_tokens return split_tokens
def convert_tokens_to_ids(self, tokens): def _convert_token_to_id(self, token):
""" Converts a sequence of tokens into ids using the vocab. """ """ Converts a token (str/unicode) in an id using the vocab. """
ids = [] return self.encoder.get(token, self.encoder.get(self.unk_token))
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False): def _convert_id_to_token(self, index):
"""Converts a sequence of ids in BPE tokens using the vocab.""" """Converts an id in a token (BPE) using the vocab."""
tokens = [] return self.decoder.get(index, self.unk_token)
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text): def _convert_ids_to_string(self, tokens_ids):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string.""" """Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '')
out_string = clean_up_tokenization(out_string)
return out_string return out_string
def save_vocabulary(self, vocab_path): def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory.""" """Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path): if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return return
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file'])
with open(vocab_file, 'w', encoding='utf-8') as f: with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False)) f.write(json.dumps(self.encoder, ensure_ascii=False))
@ -270,14 +201,4 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
writer.write(' '.join(bpe_tokens) + u'\n') writer.write(' '.join(bpe_tokens) + u'\n')
index += 1 index += 1
index = len(self.encoder) return vocab_file, merge_file
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
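# Sketch: the removed set_special_tokens()/special_tokens machinery is replaced by
# the shared add_tokens() API of PreTrainedTokenizer, which indexes new symbols
# after the BPE vocabulary. Module path and network access are assumed;
# '__classify__' is the example symbol from the old docstring.
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
num_added = tokenizer.add_tokens(['__classify__'])
assert num_added == 1
assert tokenizer.convert_tokens_to_ids('__classify__') == len(tokenizer) - 1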

View File

@ -41,7 +41,7 @@ else:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'} VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'pretrained_vocab_file': 'pretrained_vocab_file':
@ -67,9 +67,17 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
delimiter=None, vocab_file=None, pretrained_vocab_file=None, delimiter=None, vocab_file=None, pretrained_vocab_file=None,
never_split=("<unk>", "<eos>", "<formula>")): never_split=None, unk_token="<unk>", eos_token="<eos>",
additional_special_tokens=["<formula>"], **kwargs):
super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
if never_split is None:
never_split = self.all_special_tokens
if special is None:
special = []
self.counter = Counter() self.counter = Counter()
self.special = special self.special = special
self.min_freq = min_freq self.min_freq = min_freq
@ -200,11 +208,13 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self.idx2sym.append(sym) self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1 self.sym2idx[sym] = len(self.idx2sym) - 1
def get_sym(self, idx): def _convert_id_to_token(self, idx):
"""Converts an id in a token (BPE) using the vocab."""
assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx) assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
return self.idx2sym[idx] return self.idx2sym[idx]
def get_idx(self, sym): def _convert_token_to_id(self, sym):
""" Converts a token (str/unicode) in an id using the vocab. """
if sym in self.sym2idx: if sym in self.sym2idx:
return self.sym2idx[sym] return self.sym2idx[sym]
else: else:
@ -220,36 +230,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
else: else:
raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement') raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
def convert_ids_to_tokens(self, indices): def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of indices in symbols using the vocab.""" """Converts a sequence of ids in a string."""
return [self.get_sym(idx) for idx in indices] out_string = ' '.join(tokens_ids).strip()
return out_string
def convert_tokens_to_ids(self, symbols):
"""Converts a sequence of symbols into ids using the vocab."""
return [self.get_idx(sym) for sym in symbols]
def convert_to_tensor(self, symbols): def convert_to_tensor(self, symbols):
return torch.LongTensor(self.convert_tokens_to_ids(symbols)) return torch.LongTensor(self.convert_tokens_to_ids(symbols))
def encode(self, text): @property
return self.convert_tokens_to_ids(self.tokenize(text)) def vocab_size(self):
def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
"""Converts a sequence of indices in a string."""
if exclude is None:
out_string = ' '.join([self.get_sym(idx) for idx in indices])
else:
out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
if clean_up_tokenization_spaces:
out_string = clean_up_tokenization(out_string)
return out_string
def __len__(self):
return len(self.idx2sym) return len(self.idx2sym)
def tokenize(self, line, add_eos=False, add_double_eos=False): def _tokenize(self, line, add_eos=False, add_double_eos=False):
line = line.strip() line = line.strip()
# convert to lower case # convert to lower case
if self.lower_case: if self.lower_case:

View File

@ -16,37 +16,145 @@
from __future__ import (absolute_import, division, print_function, from __future__ import (absolute_import, division, print_function,
unicode_literals) unicode_literals)
import sys
import json
import logging import logging
import os import os
import regex as re import json
import six
from io import open from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
from .file_utils import cached_path from .file_utils import cached_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'
class PreTrainedTokenizer(object): class PreTrainedTokenizer(object):
""" An abstract class to handle dowloading and loading pretrained tokenizers. """ An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary.
Derived class can set up a few special tokens to be used in common scripts and internals:
bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
additional_special_tokens = []
We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...).
""" """
vocab_files_names = {} vocab_files_names = {}
pretrained_vocab_files_map = {} pretrained_vocab_files_map = {}
max_model_input_sizes = {} max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
"pad_token", "cls_token", "mask_token",
"additional_special_tokens"]
@property
def bos_token(self):
if self._bos_token is None:
logger.error("Using bos_token, but it is not set yet.")
return self._bos_token
@property
def eos_token(self):
if self._eos_token is None:
logger.error("Using eos_token, but it is not set yet.")
return self._eos_token
@property
def unk_token(self):
if self._unk_token is None:
logger.error("Using unk_token, but it is not set yet.")
return self._unk_token
@property
def sep_token(self):
if self._sep_token is None:
logger.error("Using sep_token, but it is not set yet.")
return self._sep_token
@property
def pad_token(self):
if self._pad_token is None:
logger.error("Using pad_token, but it is not set yet.")
return self._pad_token
@property
def cls_token(self):
if self._cls_token is None:
logger.error("Using cls_token, but it is not set yet.")
return self._cls_token
@property
def mask_token(self):
if self._mask_token is None:
logger.error("Using mask_token, but it is not set yet.")
return self._mask_token
@property
def additional_special_tokens(self):
if self._additional_special_tokens is None:
logger.error("Using additional_special_tokens, but it is not set yet.")
return self._additional_special_tokens
@bos_token.setter
def bos_token(self, value):
self._bos_token = value
@eos_token.setter
def eos_token(self, value):
self._eos_token = value
@unk_token.setter
def unk_token(self, value):
self._unk_token = value
@sep_token.setter
def sep_token(self, value):
self._sep_token = value
@pad_token.setter
def pad_token(self, value):
self._pad_token = value
@cls_token.setter
def cls_token(self, value):
self._cls_token = value
@mask_token.setter
def mask_token(self, value):
self._mask_token = value
@additional_special_tokens.setter
def additional_special_tokens(self, value):
self._additional_special_tokens = value
def __init__(self, max_len=None, **kwargs):
self._bos_token = None
self._eos_token = None
self._unk_token = None
self._sep_token = None
self._pad_token = None
self._cls_token = None
self._mask_token = None
self._additional_special_tokens = []
self.max_len = max_len if max_len is not None else int(1e12)
self.added_tokens_encoder = {}
self.added_tokens_decoder = {}
for key, value in kwargs.items():
if key not in self.SPECIAL_TOKENS_ATTRIBUTES:
raise ValueError(
"PreTrainedTokenizer.__init__() argument {} should be in {}".format(
key, ', '.join(self.SPECIAL_TOKENS_ATTRIBUTES)))
else:
setattr(self, key, value)
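# A minimal sketch of the constructor contract: special tokens are passed as keyword
# arguments, checked against SPECIAL_TOKENS_ATTRIBUTES and exposed as settable
# properties. PreTrainedTokenizer is abstract; instantiating it directly is only
# for illustration here.
from pytorch_transformers import PreTrainedTokenizer

tok = PreTrainedTokenizer(unk_token='<unk>', bos_token='<s>',
                          additional_special_tokens=['<special0>'])
assert tok.unk_token == '<unk>' and tok.bos_token == '<s>'
assert tok.special_tokens_map == {'unk_token': '<unk>', 'bos_token': '<s>',
                                  'additional_special_tokens': ['<special0>']}
try:
    PreTrainedTokenizer(not_a_special_token='x')   # hypothetical unknown key
except ValueError:
    pass                                           # unrecognized keys are rejected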
@classmethod @classmethod
def from_pretrained(cls, *inputs, **kwargs): def from_pretrained(cls, *inputs, **kwargs):
return cls._from_pretrained(*inputs, **kwargs) return cls._from_pretrained(*inputs, **kwargs)
@classmethod @classmethod
def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
""" """
@ -59,16 +167,20 @@ class PreTrainedTokenizer(object):
for file_id, map_list in cls.pretrained_vocab_files_map.items(): for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path] vocab_files[file_id] = map_list[pretrained_model_name_or_path]
else: else:
for file_id, file_name in cls.vocab_files_names.items(): all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
all_vocab_files_names.update(cls.vocab_files_names)
for file_id, file_name in all_vocab_files_names.items():
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
full_file_name = os.path.join(pretrained_model_name_or_path, file_name) full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
else: else:
full_file_name = pretrained_model_name_or_path full_file_name = pretrained_model_name_or_path
if not os.path.exists(full_file_name): if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We don't load it.".format(full_file_name)) logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
full_file_name = None full_file_name = None
vocab_files[file_id] = full_file_name vocab_files[file_id] = full_file_name
# redirect to the cache, if necessary
# Get files from url, cache, or disk depending on the case
try: try:
resolved_vocab_files = {} resolved_vocab_files = {}
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
@ -95,6 +207,7 @@ class PreTrainedTokenizer(object):
logger.info("loading file {} from cache at {}".format( logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id])) file_path, resolved_vocab_files[file_id]))
# Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes: if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer # if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings # wont index sequences longer than the number of positional embeddings
@ -102,31 +215,255 @@ class PreTrainedTokenizer(object):
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Merge resolved_vocab_files arguments in kwargs. # Merge resolved_vocab_files arguments in kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
for args_name, file_path in resolved_vocab_files.items(): for args_name, file_path in resolved_vocab_files.items():
kwargs[args_name] = file_path if args_name not in kwargs:
kwargs[args_name] = file_path
if special_tokens_map_file is not None:
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
for key, value in special_tokens_map.items():
if key not in kwargs:
kwargs[key] = value
# Instantiate tokenizer. # Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs) tokenizer = cls(*inputs, **kwargs)
# Add supplementary tokens.
if added_tokens_file is not None:
added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
tokenizer.added_tokens_encoder.update(added_tok_encoder)
tokenizer.added_tokens_decoder.update(added_tok_decoder)
return tokenizer return tokenizer
def tokenize(self, text):
def save_pretrained(self, save_directory):
""" Save the tokenizer vocabulary files (with added tokens) and the
special-tokens-to-class-attributes-mapping to a directory, so that it
can be re-loaded using the `from_pretrained(save_directory)` class method.
"""
if not os.path.isdir(save_directory):
logger.error("Saving directory ({}) should be a directory".format(save_directory))
return
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
with open(added_tokens_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
vocab_files = self.save_vocabulary(save_directory)
return vocab_files + (special_tokens_map_file, added_tokens_file)
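# Sketch of the new serialization: save_pretrained() writes the vocabulary files
# plus special_tokens_map.json and added_tokens.json, and from_pretrained() can
# reload from that directory. './my_tokenizer' is a hypothetical path; network
# access is assumed for the initial 'bert-base-uncased' download.
import os
from pytorch_transformers.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
os.makedirs('./my_tokenizer', exist_ok=True)
saved_files = tokenizer.save_pretrained('./my_tokenizer')
print(saved_files)   # vocab file(s) + special_tokens_map.json + added_tokens.json
reloaded = BertTokenizer.from_pretrained('./my_tokenizer')
assert reloaded.special_tokens_map == tokenizer.special_tokens_map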
def save_vocabulary(self, save_directory):
""" Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
and special token mappings.
Please use `save_pretrained()` to save the full Tokenizer state so that it can be
reloaded using the `from_pretrained(save_directory)` class method.
"""
raise NotImplementedError
def vocab_size(self):
raise NotImplementedError
def __len__(self):
return self.vocab_size + len(self.added_tokens_encoder)
def add_tokens(self, new_tokens):
""" Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to the added_tokens_encoder with indices starting from
the last index of the current vocabulary.
Returns:
Number of tokens added to the vocabulary which can be used to correspondingly
increase the size of the associated model embedding matrices.
"""
if not new_tokens:
return 0
to_add_tokens = []
for token in new_tokens:
if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
to_add_tokens.append(token)
logger.info("Adding %s to the vocabulary", token)
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
return len(to_add_tokens)
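# Sketch: add_tokens() registers new symbols in added_tokens_encoder, indexed after
# the underlying vocabulary, and tokenize() never splits them; the returned count is
# meant for resizing the model's embedding matrix. Network access assumed for
# 'bert-base-uncased'; '[E1]'/'[E2]' are made-up marker tokens.
from pytorch_transformers.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
num_added = tokenizer.add_tokens(['[E1]', '[E2]'])
assert num_added == 2
assert len(tokenizer) == tokenizer.vocab_size + 2
print(tokenizer.tokenize(u'something about [E1] here'))   # '[E1]' survives as one token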
def add_special_tokens(self, special_tokens_dict):
""" Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them
to class attributes. If the special tokens are not in the vocabulary, they are added
to it and indexed starting from the last index of the current vocabulary.
Returns:
Number of tokens added to the vocabulary which can be used to correspondingly
increase the size of the associated model embedding matrices.
"""
if not special_tokens_dict:
return 0
added_special_tokens = self.add_tokens(special_tokens_dict.values())
for key, value in special_tokens_dict.items():
logger.info("Assigning %s to the %s key of the tokenizer", value, key)
setattr(self, key, value)
return added_special_tokens
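# Sketch: add_special_tokens() both adds the tokens to the vocabulary (through
# add_tokens) and binds them to the class attributes, so scripts can rely on e.g.
# tokenizer.eos_token being set. '[EOS]' is a made-up token; BERT defines no
# eos_token by default. Network access assumed for 'bert-base-uncased'.
from pytorch_transformers.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
num_added = tokenizer.add_special_tokens({'eos_token': '[EOS]'})
assert num_added == 1
assert tokenizer.eos_token == '[EOS]'
assert '[EOS]' in tokenizer.all_special_tokens
assert tokenizer.convert_tokens_to_ids('[EOS]') == tokenizer.vocab_size   # first added index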
def tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).
Take care of added tokens.
"""
def split_on_tokens(tok_list, text):
if not text:
return []
if not tok_list:
return self._tokenize(text, **kwargs)
tok = tok_list[0]
split_text = text.split(tok)
return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
for sub_text in split_text), [])[:-1]
added_tokens = list(self.added_tokens_encoder.keys())
tokenized_text = split_on_tokens(added_tokens, text)
return tokenized_text
def _tokenize(self, text, **kwargs):
""" Converts a string in a sequence of tokens (string), using the tokenizer.
Split in words for word-based vocabulary or sub-words for sub-word-based
vocabularies (BPE/SentencePieces/WordPieces).
Don't take care of added tokens.
"""
raise NotImplementedError raise NotImplementedError
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
""" Converts a single token or a sequence of tokens (str/unicode) in a integer id
(resp.) a sequence of ids, using the vocabulary.
"""
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
return self.convert_token_to_id_with_added_voc(tokens)
ids = []
for token in tokens:
ids.append(self.convert_token_to_id_with_added_voc(token))
if len(ids) > self.max_len:
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({} > {}). Running this sequence through the model will result in "
"indexing errors".format(len(ids), self.max_len))
return ids
def convert_token_to_id_with_added_voc(self, token):
if token in self.added_tokens_encoder:
return self.added_tokens_encoder[token]
return self._convert_token_to_id(token)
def _convert_token_to_id(self, token):
raise NotImplementedError raise NotImplementedError
def convert_ids_to_tokens(self, ids):
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
(resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
Args:
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
"""
if isinstance(ids, int):
return self.convert_id_to_token(ids)
tokens = []
for index in ids:
if index in self.all_special_ids and skip_special_tokens:
continue
if index in self.added_tokens_decoder:
tokens.append(self.added_tokens_decoder[index])
else:
tokens.append(self._convert_id_to_token(index))
return tokens
def _convert_id_to_token(self, index):
raise NotImplementedError raise NotImplementedError
def encode(self, text): def encode(self, text):
raise NotImplementedError """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
same as self.convert_tokens_to_ids(self.tokenize(text)).
"""
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, token_ids, *input, **kwargs):
raise NotImplementedError
def save_vocabulary(self, vocab_path): def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
raise NotImplementedError """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self._convert_ids_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
return text
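# Sketch of the decode() options: skip_special_tokens drops ids listed in
# all_special_ids before detokenization, and clean_up_tokenization_spaces runs
# clean_up_tokenization() on the result. Network access assumed for 'gpt2'.
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode(u"Hello world") + tokenizer.convert_tokens_to_ids([tokenizer.eos_token])
print(tokenizer.decode(ids))                             # 'Hello world<|endoftext|>'
print(tokenizer.decode(ids, skip_special_tokens=True))   # 'Hello world'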
def _convert_ids_to_string(self, tokens_ids):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
"""
return ' '.join(self.convert_ids_to_tokens(tokens_ids))
@property
def special_tokens_map(self):
""" A dictionary mapping special token class attribute (cls_token, unk_token...) to their
values ('<unk>', '<cls>'...)
"""
set_attr = {}
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
attr_value = getattr(self, "_" + attr)
if attr_value:
set_attr[attr] = attr_value
return set_attr
@property
def all_special_tokens(self):
""" List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
(cls_token, unk_token...).
"""
all_toks = []
set_attr = self.special_tokens_map
for attr_value in set_attr.values():
all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
all_toks = list(set(all_toks))
return all_toks
@property
def all_special_ids(self):
""" List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
class attributes (cls_token, unk_token...).
"""
all_toks = self.all_special_tokens
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
return all_ids
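# Sketch of the three introspection helpers above on an XLM tokenizer, which sets
# the full range of special tokens plus ten additional ones.
# Network access assumed for the 'xlm-mlm-en-2048' vocabulary and merges files.
from pytorch_transformers.tokenization_xlm import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
print(tokenizer.special_tokens_map)     # {'bos_token': '<s>', 'unk_token': '<unk>', ...}
print(tokenizer.all_special_tokens)     # deduplicated list of the values above
print(tokenizer.all_special_ids)        # their vocabulary indices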
def clean_up_tokenization(out_string): def clean_up_tokenization(out_string):

View File

@ -34,7 +34,6 @@ logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json', 'vocab_file': 'vocab.json',
'merges_file': 'merges.txt', 'merges_file': 'merges.txt',
'special_tokens_file': 'special_tokens.txt'
} }
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
@ -46,24 +45,12 @@ PRETRAINED_VOCAB_FILES_MAP = {
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-en-2048': 512,
}
def get_pairs(word):
"""
Return set of symbol pairs in a word.
@@ -103,7 +90,16 @@ class XLMTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
sep_token="</s>", pad_token="<pad>", cls_token="</s>",
mask_token="<special1>", additional_special_tokens=["<special0>",
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
"<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
sep_token=sep_token, pad_token=pad_token,
cls_token=cls_token, mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
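# Sketch, not from this diff: with the unified API the XLM special tokens are registered on the
# PreTrainedTokenizer base class instead of a separate special_tokens.txt file, so (illustrative):
#   tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
#   tokenizer.cls_token, tokenizer.mask_token    # '</s>', '<special1>' per the defaults above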
try:
import ftfy
import spacy
@@ -111,11 +107,9 @@ class XLMTokenizer(PreTrainedTokenizer):
self.fix_text = ftfy.fix_text
except ImportError:
logger.warning("ftfy or spacy is not installed; using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True)
self.fix_text = None
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
@@ -123,35 +117,9 @@ class XLMTokenizer(PreTrainedTokenizer):
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
@property
def vocab_size(self):
return len(self.encoder)
def bpe(self, token):
word = tuple(token[:-1]) + (token[-1] + '</w>',)
@@ -196,7 +164,7 @@ class XLMTokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
def _tokenize(self, text):
""" Tokenize a string. """
split_tokens = []
if self.fix_text is None:
@@ -211,58 +179,26 @@ class XLMTokenizer(PreTrainedTokenizer):
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) to an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) to a token (string/unicode) using the vocab."""
return self.decoder.get(index, self.unk_token)
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids into a string."""
out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
return out_string
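# Worked example of the '</w>' join above (the token strings are illustrative only):
#   tokens = ['he', 'llo</w>', 'wo', 'rld</w>']
#   ''.join(tokens).replace('</w>', ' ').strip()    # -> 'hello world'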
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
@@ -277,14 +213,4 @@ class XLMTokenizer(PreTrainedTokenizer):
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
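# Illustrative save/reload round trip (the directory name is an assumption):
#   vocab_file, merge_file = tokenizer.save_vocabulary('./xlm-tok')
#   reloaded = XLMTokenizer(vocab_file, merge_file)
# The special-token configuration is expected to be serialized by the unified save/load path
# on the PreTrainedTokenizer base class rather than by a special_tokens.txt file.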
View File
@@ -16,17 +16,13 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import os
from shutil import copyfile
import unicodedata
import six
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
logger = logging.getLogger(__name__)
@@ -44,8 +40,6 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-large-cased': 512,
}
SPIECE_UNDERLINE = u'▁'
# Segments (not really needed)
@@ -60,31 +54,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
SentencePiece based tokenizer. Peculiarities:
- requires SentencePiece: https://github.com/google/sentencepiece
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, max_len=None,
do_lower_case=False, remove_space=True, keep_accents=False,
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"], **kwargs):
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens, **kwargs)
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: "
"https://github.com/google/sentencepiece (pip install sentencepiece)")
self.do_lower_case = do_lower_case
self.remove_space = remove_space
self.keep_accents = keep_accents
@@ -94,46 +83,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
self.sp_model.Load(vocab_file)
@property
def vocab_size(self):
return len(self.sp_model)
def __getstate__(self):
@@ -169,7 +119,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
return outputs
def _tokenize(self, text, return_unicode=True, sample=False):
""" Tokenize a string.
return_unicode is used only for py2
"""
@@ -208,56 +158,30 @@ class XLNetTokenizer(PreTrainedTokenizer):
return new_pieces
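# Sketch, not from this diff: _tokenize delegates to SentencePiece, which marks word
# boundaries with SPIECE_UNDERLINE rather than '</w>'. The exact pieces depend on the
# trained spiece.model, so the output below is only an assumption:
#   tokenizer._tokenize("Hello world")    # e.g. ['▁Hello', '▁world']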
def _convert_token_to_id(self, token):
""" Converts a token (str/unicode) to an id using the vocab. """
return self.sp_model.PieceToId(token)
def _convert_id_to_token(self, index, return_unicode=True):
"""Converts an index (integer) to a token (string/unicode) using the vocab."""
token = self.sp_model.IdToPiece(index)
if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8')
return token
def _convert_ids_to_string(self, tokens_ids):
"""Converts a sequence of ids into a string."""
out_string = ''.join(tokens_ids)
return out_string
def save_vocabulary(self, save_directory):
""" Save the sentencepiece vocabulary (copy original file) to a directory.
"""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
copyfile(self.vocab_file, out_vocab_file)
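# Illustrative round trip for the copied SentencePiece model (the directory name and the
# 'spiece.model' filename under VOCAB_FILES_NAMES['vocab_file'] are assumptions):
#   tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
#   tokenizer.save_vocabulary('./xlnet-tok')
#   reloaded = XLNetTokenizer('./xlnet-tok/spiece.model')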