diff --git a/examples/run_glue.py b/examples/run_glue.py index 8dd845a5532..59583ed712c 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -32,9 +32,11 @@ from torch.utils.data.distributed import DistributedSampler from tensorboardX import SummaryWriter -from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME -from pytorch_transformers.modeling_bert import BertForSequenceClassification -from pytorch_transformers.tokenization_bert import BertTokenizer +from pytorch_transformers import (BertForSequenceClassification, XLNetForSequenceClassification, + XLMForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) +from pytorch_transformers import (BertTokenizer, XLNetTokenizer, + XLMTokenizer) from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule from utils_glue import processors, output_modes, convert_examples_to_features, compute_metrics @@ -42,6 +44,21 @@ from utils_glue import processors, output_modes, convert_examples_to_features, c logger = logging.getLogger(__name__) +ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, + XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ()) + +MODEL_CLASSES = { + 'bert': BertForSequenceClassification, + 'xlnet': XLNetForSequenceClassification, + 'xlm': XLMForSequenceClassification, +} + +TOKENIZER_CLASSES = { + 'bert': BertTokenizer, + 'xlnet': XLNetTokenizer, + 'xlm': XLMTokenizer, +} def train(args, train_features, model): """ Train the model """ @@ -156,7 +173,7 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model): # Eval! logger.info("***** Running evaluation *****") - logger.info(" Num examples = %d", len(eval_examples)) + logger.info(" Num examples = %d", len(eval_features)) logger.info(" Batch size = %d", args.eval_batch_size) model.eval() eval_loss = 0 @@ -208,7 +225,7 @@ def load_and_cache_examples(args, task, tokenizer, eval=False): examples = processor.get_dev_examples(args.data_dir) cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format( 'dev' if eval else 'train', - list(filter(None, args.bert_model.split('/'))).pop(), + list(filter(None, args.model_name.split('/'))).pop(), str(args.max_seq_length), str(task))) @@ -217,6 +234,11 @@ def load_and_cache_examples(args, task, tokenizer, eval=False): features = torch.load(cached_features_file) else: features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode) + features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode, + cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']), + cls_token=tokenizer.cls_token, + sep_token=tokenizer.sep_token, cls_token_segment_id=2, + pad_on_left=True, pad_token_segment_id=4) if args.local_rank == -1 or torch.distributed.get_rank() == 0: logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) @@ -230,12 +252,10 @@ def main(): ## Required parameters parser.add_argument("--data_dir", default=None, type=str, required=True, help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") - parser.add_argument("--bert_model", default=None, type=str, required=True, - help="Bert pre-trained model selected in the list: bert-base-uncased, " - "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " - "bert-base-multilingual-cased, bert-base-chinese.") + parser.add_argument("--model_name", default=None, type=str, required=True, + help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS)) parser.add_argument("--task_name", default=None, type=str, required=True, - help="The name of the task to train.") + help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.") @@ -243,9 +263,8 @@ def main(): parser.add_argument("--cache_dir", default="", type=str, help="Where do you want to store the pre-trained models downloaded from s3") parser.add_argument("--max_seq_length", default=128, type=int, - help="The maximum total input sequence length after WordPiece tokenization. \n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.") + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', @@ -263,8 +282,7 @@ def main(): parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") parser.add_argument("--warmup_proportion", default=0.1, type=float, - help="Proportion of training to perform linear learning rate warmup for. 
" - "E.g., 0.1 = 10%% of training.") + help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).") parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available") parser.add_argument('--overwrite_output_dir', action='store_true', @@ -331,8 +349,11 @@ def main(): # Make sure only the first process in distributed training will download model & vocab torch.distributed.barrier() - tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) - model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels) + args.model_type = args.model_name.lower().split('-')[0] + args.tokenizer_class = TOKENIZER_CLASSES[args.model_type] + args.model_class = MODEL_CLASSES[args.model_type] + tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case) + model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels) if args.local_rank == 0: torch.distributed.barrier() @@ -359,27 +380,16 @@ def main(): # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): # Save a trained model, configuration and tokenizer - model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self - - # If we save using the predefined names, we can load using `from_pretrained` - output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) - output_config_file = os.path.join(args.output_dir, CONFIG_NAME) - - torch.save(model_to_save.state_dict(), output_model_file) - model_to_save.config.to_json_file(output_config_file) + model.save_pretrained(args.output_dir) tokenizer.save_vocabulary(args.output_dir) - # Load a trained model and vocabulary that you have fine-tuned - model = BertForSequenceClassification.from_pretrained(args.output_dir) - tokenizer = BertTokenizer.from_pretrained(args.output_dir) - # Good practice: save your training arguments together with the trained model - output_args_file = os.path.join(args.output_dir, 'training_args.bin') - torch.save(args, output_args_file) - else: - model = BertForSequenceClassification.from_pretrained(args.bert_model) + torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) - model.to(args.device) + # Load a trained model and vocabulary that you have fine-tuned + model = args.model_class.from_pretrained(args.output_dir) + tokenizer = args.tokenizer_class.from_pretrained(args.output_dir) + model.to(args.device) # Evaluation if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): diff --git a/examples/run_xlnet_classifier.py b/examples/run_xlnet_classifier.py index 7cf8a8d8778..35b0ebfbd1a 100644 --- a/examples/run_xlnet_classifier.py +++ b/examples/run_xlnet_classifier.py @@ -211,8 +211,8 @@ def main(): logger.info("No cache file at %s, preparing train features", cached_train_features_file) train_features = convert_examples_to_features( train_examples, label_list, args.max_seq_length, tokenizer, output_mode, - cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN, - sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2, + cls_token_at_end=True, cls_token=tokenizer.cls_token, + sep_token=tokenizer.sep_token, cls_token_segment_id=2, pad_on_left=True, pad_token_segment_id=4) if args.local_rank == -1 or torch.distributed.get_rank() == 0: logger.info(" Saving train features into cached file %s", 
cached_train_features_file) @@ -369,8 +369,8 @@ def main(): logger.info("No cache file at %s, preparing eval features", cached_eval_features_file) eval_features = convert_examples_to_features( eval_examples, label_list, args.max_seq_length, tokenizer, output_mode, - cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN, - sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2, + cls_token_at_end=True, cls_token=tokenizer.cls_token, + sep_token=tokenizer.sep_token, cls_token_segment_id=2, pad_on_left=True, pad_token_segment_id=4) if args.local_rank == -1 or torch.distributed.get_rank() == 0: logger.info(" Saving eval features into cached file %s", cached_eval_features_file) diff --git a/examples/utils_glue.py b/examples/utils_glue.py index 18e733567dc..47505929573 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -396,7 +396,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, mask_padding_with_zero=True): """ Loads a data file into a list of `InputBatch`s `cls_token_at_end` define the location of the CLS token: - - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP] + - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) """ @@ -489,8 +489,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, [str(x) for x in tokens])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) - logger.info( - "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) + logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) logger.info("label: %s (id = %d)" % (example.label, label_id)) features.append( diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 6dd78dfd025..c8f64a07def 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -11,22 +11,28 @@ from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering, - load_tf_weights_in_bert) + load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, - load_tf_weights_in_openai_gpt) + load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, + OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, - load_tf_weights_in_transfo_xl) + load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, + TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_gpt2 import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, - load_tf_weights_in_gpt2) + load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, + GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_xlnet import (XLNetConfig, XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, - load_tf_weights_in_xlnet) + load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, + XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_xlm import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, - XLMForQuestionAnswering) + 
XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, + XLM_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) diff --git a/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py index 2d666a1f038..db23e5bffe6 100755 --- a/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py @@ -29,8 +29,7 @@ from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME, TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl) -from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, - VOCAB_NAME) +from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) if sys.version_info[0] == 2: import cPickle as pickle @@ -53,7 +52,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, with open(transfo_xl_dataset_file, "rb") as fp: corpus = pickle.load(fp, encoding="latin1") # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) - pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME + pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) corpus_vocab_dict = corpus.vocab.__dict__ torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) diff --git a/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py b/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py index 0cbe962cead..e5815252f12 100755 --- a/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py +++ b/pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py @@ -24,7 +24,7 @@ import torch import numpy from pytorch_transformers.modeling_xlm import (CONFIG_NAME, WEIGHTS_NAME, XLMConfig, XLMModel) -from pytorch_transformers.tokenization_xlm import MERGES_NAME, VOCAB_NAME +from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): @@ -42,7 +42,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p # Save pytorch-model pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME - pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME + pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) torch.save(model, pytorch_weights_dump_path) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index b2a456209d1..0dd72b29696 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -33,7 +33,7 @@ from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrai logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = { +BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", @@ -49,7 +49,7 @@ 
PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", } -PRETRAINED_CONFIG_ARCHIVE_MAP = { +BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", @@ -152,7 +152,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class BertConfig(PretrainedConfig): """Configuration class to store the configuration of a `BertModel`. """ - pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP + pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, vocab_size_or_config_json_file=30522, @@ -543,7 +543,7 @@ class BertPreTrainedModel(PreTrainedModel): a simple interface for dowloading and loading pretrained models. """ config_class = BertConfig - pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP + pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index 090763cda18..9340ce84895 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -37,9 +37,9 @@ from .modeling_bert import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", +GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"} -PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", +GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"} def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): @@ -103,7 +103,7 @@ def gelu(x): class GPT2Config(PretrainedConfig): """Configuration class to store the configuration of a `GPT2Model`. """ - pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP + pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__( self, @@ -358,7 +358,7 @@ class GPT2PreTrainedModel(PreTrainedModel): a simple interface for dowloading and loading pretrained models. 
""" config_class = GPT2Config - pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP + pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index b715b183713..4a3ff732f6c 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -37,8 +37,8 @@ from .modeling_bert import BertLayerNorm as LayerNorm logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"} -PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"} +OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"} +OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"} def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): @@ -130,7 +130,7 @@ ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu} class OpenAIGPTConfig(PretrainedConfig): """Configuration class to store the configuration of a `OpenAIGPTModel`. """ - pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP + pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__( self, @@ -384,7 +384,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): a simple interface for dowloading and loading pretrained models. """ config_class = OpenAIGPTConfig - pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP + pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_openai_gpt base_model_prefix = "transformer" diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 465577b0028..35a1b635f91 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -41,10 +41,10 @@ from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrai logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = { +TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin", } -PRETRAINED_CONFIG_ARCHIVE_MAP = { +TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", } @@ -179,7 +179,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path): class TransfoXLConfig(PretrainedConfig): """Configuration class to store the configuration of a `TransfoXLModel`. """ - pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP + pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, vocab_size_or_config_json_file=267735, @@ -838,7 +838,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel): a simple interface for dowloading and loading pretrained models. 
""" config_class = TransfoXLConfig - pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP + pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_transfo_xl base_model_prefix = "transformer" diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 96558704ea3..b9be1a38133 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module): model_to_prune = getattr(self, self.base_model_prefix, self) # get the base model if needed model_to_prune._prune_heads(heads_to_prune) + def save_pretrained(self, save_directory): + """ Save a model with its configuration file to a directory, so that it + can be re-loaded using the `from_pretrained(save_directory)` class method. + """ + assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" + + # Only save the model it-self if we are using distributed training + model_to_save = self.module if hasattr(self, 'module') else self + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(save_directory, WEIGHTS_NAME) + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): """ diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 14f8848a42d..c7ea294dbd4 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -40,10 +40,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = { +XLM_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-pytorch_model.bin", } -PRETRAINED_CONFIG_ARCHIVE_MAP = { +XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", } @@ -51,7 +51,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = { class XLMConfig(PretrainedConfig): """Configuration class to store the configuration of a `XLMModel`. """ - pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP + pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, vocab_size_or_config_json_file=30145, @@ -357,7 +357,7 @@ class XLMPreTrainedModel(PreTrainedModel): a simple interface for dowloading and loading pretrained models. 
""" config_class = XLMConfig - pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP + pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = None base_model_prefix = "transformer" diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index 289dcbd9db6..628dbe74508 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -38,10 +38,10 @@ from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTra logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = { +XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin", } -PRETRAINED_CONFIG_ARCHIVE_MAP = { +XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", } @@ -195,7 +195,7 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class XLNetConfig(PretrainedConfig): """Configuration class to store the configuration of a `XLNetModel`. """ - pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP + pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, vocab_size_or_config_json_file=32000, @@ -593,7 +593,7 @@ class XLNetPreTrainedModel(PreTrainedModel): a simple interface for dowloading and loading pretrained models. """ config_class = XLNetConfig - pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP + pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP load_tf_weights = load_tf_weights_in_xlnet base_model_prefix = "transformer" diff --git a/pytorch_transformers/tests/modeling_bert_test.py b/pytorch_transformers/tests/modeling_bert_test.py index 2ba59317be0..fbdce293663 100644 --- a/pytorch_transformers/tests/modeling_bert_test.py +++ b/pytorch_transformers/tests/modeling_bert_test.py @@ -24,7 +24,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertForMultipleChoice) -from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) @@ -267,7 +267,7 @@ class BertModelTest(unittest.TestCase): @pytest.mark.slow def test_model_from_pretrained(self): cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: model = BertModel.from_pretrained(model_name, cache_dir=cache_dir) shutil.rmtree(cache_dir) self.assertIsNotNone(model) diff --git a/pytorch_transformers/tests/modeling_tests_commons.py b/pytorch_transformers/tests/modeling_tests_commons.py index b831f85552c..db79b017c1e 100644 --- a/pytorch_transformers/tests/modeling_tests_commons.py +++ b/pytorch_transformers/tests/modeling_tests_commons.py @@ -413,7 +413,7 @@ class GPTModelTester(object): def create_and_check_model_from_pretrained(self): cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]: model = self.base_model_class.from_pretrained(model_name, 
cache_dir=cache_dir) shutil.rmtree(cache_dir) self.parent.assertIsNotNone(model) diff --git a/pytorch_transformers/tests/modeling_transfo_xl_test.py b/pytorch_transformers/tests/modeling_transfo_xl_test.py index f2906d879fe..49ba1addf14 100644 --- a/pytorch_transformers/tests/modeling_transfo_xl_test.py +++ b/pytorch_transformers/tests/modeling_transfo_xl_test.py @@ -26,7 +26,7 @@ import pytest import torch from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) -from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor @@ -185,7 +185,7 @@ class TransfoXLModelTest(unittest.TestCase): @pytest.mark.slow def test_model_from_pretrained(self): cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) shutil.rmtree(cache_dir) self.assertIsNotNone(model) diff --git a/pytorch_transformers/tests/modeling_utils_test.py b/pytorch_transformers/tests/modeling_utils_test.py index 5e3b8e676a2..a168c246114 100644 --- a/pytorch_transformers/tests/modeling_utils_test.py +++ b/pytorch_transformers/tests/modeling_utils_test.py @@ -20,12 +20,12 @@ import unittest import logging from pytorch_transformers import PretrainedConfig, PreTrainedModel -from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP class ModelUtilsTest(unittest.TestCase): def test_model_from_pretrained(self): logging.basicConfig(level=logging.INFO) - for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: config = BertConfig.from_pretrained(model_name) self.assertIsNotNone(config) self.assertIsInstance(config, PretrainedConfig) diff --git a/pytorch_transformers/tests/modeling_xlm_test.py b/pytorch_transformers/tests/modeling_xlm_test.py index 9c511f21a86..6e2e082d194 100644 --- a/pytorch_transformers/tests/modeling_xlm_test.py +++ b/pytorch_transformers/tests/modeling_xlm_test.py @@ -21,7 +21,7 @@ import shutil import pytest from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification) -from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) @@ -251,7 +251,7 @@ class XLMModelTest(unittest.TestCase): @pytest.mark.slow def test_model_from_pretrained(self): cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) shutil.rmtree(cache_dir) self.assertIsNotNone(model) diff --git a/pytorch_transformers/tests/modeling_xlnet_test.py b/pytorch_transformers/tests/modeling_xlnet_test.py index b762426d2c7..e167e2d2e82 100644 --- a/pytorch_transformers/tests/modeling_xlnet_test.py +++ b/pytorch_transformers/tests/modeling_xlnet_test.py @@ -26,7 +26,7 @@ 
import pytest import torch from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering) -from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor @@ -279,7 +279,7 @@ class XLNetModelTest(unittest.TestCase): @pytest.mark.slow def test_model_from_pretrained(self): cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir) shutil.rmtree(cache_dir) self.assertIsNotNone(model) diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index 37e20cc2865..220bf453467 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -17,14 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest from io import open -import shutil -import pytest from pytorch_transformers.tokenization_bert import (BasicTokenizer, - BertTokenizer, - WordpieceTokenizer, - _is_control, _is_punctuation, - _is_whitespace) + BertTokenizer, + WordpieceTokenizer, + _is_control, _is_punctuation, + _is_whitespace, VOCAB_FILES_NAMES) from .tokenization_tests_commons import create_and_check_tokenizer_commons @@ -33,13 +31,15 @@ class TokenizationTest(unittest.TestCase): def test_full_tokenizer(self): vocab_tokens = [ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", - "##ing", "," + "##ing", ",", "low", "lowest", ] - with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer: + vocab_directory = "/tmp/" + vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file']) + with open(vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) vocab_file = vocab_writer.name - create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file) + create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory) tokenizer = BertTokenizer(vocab_file) @@ -80,7 +80,7 @@ class TokenizationTest(unittest.TestCase): vocab = {} for (i, token) in enumerate(vocab_tokens): vocab[token] = i - tokenizer = WordpieceTokenizer(vocab=vocab) + tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") self.assertListEqual(tokenizer.tokenize(""), []) diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index 8b06161b53c..30959ceed1d 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -17,8 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest import json +import tempfile -from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer +from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES from .tokenization_tests_commons import create_and_check_tokenizer_commons @@ -28,31 +29,31 @@ class GPT2TokenizationTest(unittest.TestCase): """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """ vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "low", "er", - "low", "lowest", "newer", "wider"] + "low", "lowest", "newer", "wider", ""] vocab_tokens = dict(zip(vocab, range(len(vocab)))) merges = ["#version: 0.2", "l o", "lo w", "e r", ""] - with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: - fp.write(json.dumps(vocab_tokens)) - vocab_file = fp.name - with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: - fp.write("\n".join(merges)) - merges_file = fp.name + special_tokens_map = {"unk_token": ""} - create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["", ""]) + with tempfile.TemporaryDirectory() as tmpdirname: + vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file']) + with open(vocab_file, "w") as fp: + fp.write(json.dumps(vocab_tokens)) + with open(merges_file, "w") as fp: + fp.write("\n".join(merges)) - tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["", ""]) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) + create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map) - input_tokens = tokens + [""] - input_bpe_tokens = [13, 12, 16] - self.assertListEqual( - tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map) + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) - os.remove(vocab_file) - os.remove(merges_file) + input_tokens = tokens + [tokenizer.unk_token] + input_bpe_tokens = [13, 12, 17] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) if __name__ == '__main__': diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index 3f8c49f8886..22f7d700176 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest import json -import shutil -import pytest +import tempfile -from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer +from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES from.tokenization_tests_commons import create_and_check_tokenizer_commons @@ -32,31 +31,31 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "w", "r", "t", "lo", "low", "er", - "low", "lowest", "newer", "wider"] + "low", "lowest", "newer", "wider", ""] vocab_tokens = dict(zip(vocab, range(len(vocab)))) merges = ["#version: 0.2", "l o", "lo w", "e r", ""] - with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: - fp.write(json.dumps(vocab_tokens)) - vocab_file = fp.name - with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: - fp.write("\n".join(merges)) - merges_file = fp.name - create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["", ""]) + with tempfile.TemporaryDirectory() as tmpdirname: + vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + merges_file = os.path.join(tmpdirname, 
VOCAB_FILES_NAMES['merges_file']) + with open(vocab_file, "w") as fp: + fp.write(json.dumps(vocab_tokens)) + with open(merges_file, "w") as fp: + fp.write("\n".join(merges)) - tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["", ""]) - os.remove(vocab_file) - os.remove(merges_file) + create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) + tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file) - input_tokens = tokens + [""] - input_bpe_tokens = [14, 15, 20] - self.assertListEqual( - tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) if __name__ == '__main__': diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 876f7747be0..07f962bcab5 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import sys from io import open +import tempfile if sys.version_info[0] == 3: unicode = str @@ -28,22 +29,19 @@ else: def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class(*inputs, **kwargs) + tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") - vocab_path="/tmp/" - output_files = tokenizer.save_vocabulary(vocab_path=vocab_path) - tokenizer = tokenizer.from_pretrained(vocab_path) - - for f in output_files: - os.remove(f) + with tempfile.TemporaryDirectory() as tmpdirname: + tokenizer.save_pretrained(tmpdirname) + tokenizer = tokenizer.from_pretrained(tmpdirname) after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") tester.assertListEqual(before_tokens, after_tokens) def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class(*inputs, **kwargs) + tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) text = u"Munich and Berlin are nice cities" filename = u"/tmp/tokenizer.bin" @@ -58,8 +56,54 @@ def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs tester.assertListEqual(subwords, subwords_loaded) +def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs): + tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) + + vocab_size = tokenizer.vocab_size + all_size = len(tokenizer) + + tester.assertNotEqual(vocab_size, 0) + tester.assertEqual(vocab_size, all_size) + + new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] + added_toks = tokenizer.add_tokens(new_toks) + vocab_size_2 = tokenizer.vocab_size + all_size_2 = len(tokenizer) + + tester.assertNotEqual(vocab_size_2, 0) + tester.assertEqual(vocab_size, vocab_size_2) + tester.assertEqual(added_toks, len(new_toks)) + tester.assertEqual(all_size_2, all_size + len(new_toks)) + + tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") + tester.assertGreaterEqual(len(tokens), 4) + tester.assertGreater(tokens[0], tokenizer.vocab_size - 1) + 
tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + + new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", + 'pad_token': "<<<<<|||>|>>>>|>"} + added_toks_2 = tokenizer.add_special_tokens(new_toks_2) + vocab_size_3 = tokenizer.vocab_size + all_size_3 = len(tokenizer) + + tester.assertNotEqual(vocab_size_3, 0) + tester.assertEqual(vocab_size, vocab_size_3) + tester.assertEqual(added_toks_2, len(new_toks_2)) + tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) + + tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") + + tester.assertGreaterEqual(len(tokens), 6) + tester.assertGreater(tokens[0], tokenizer.vocab_size - 1) + tester.assertGreater(tokens[0], tokens[1]) + tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + tester.assertGreater(tokens[-2], tokens[-3]) + tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token)) + tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) + + def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs): - tokenizer = tokenizer_class(*inputs, **kwargs) + tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) text = u"He is very happy, UNwant\u00E9d,running" tokens = tokenizer.tokenize(text) @@ -75,5 +119,6 @@ def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs): create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs) + create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs) create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs) create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs) diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index f583e30b56b..a4ddd357b9d 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest from io import open -import shutil -import pytest +import tempfile -from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer +from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES from.tokenization_tests_commons import create_and_check_tokenizer_commons @@ -28,22 +27,23 @@ class TransfoXLTokenizationTest(unittest.TestCase): def test_full_tokenizer(self): vocab_tokens = [ - "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", "," + "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", + "running", ",", "low", "l", ] - with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer: - vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - vocab_file = vocab_writer.name + with tempfile.TemporaryDirectory() as tmpdirname: + vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + with open(vocab_file, "w", encoding='utf-8') as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True) + create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True) - tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True) - 
os.remove(vocab_file) + tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True) - tokens = tokenizer.tokenize(u" UNwanted , running") - self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) + tokens = tokenizer.tokenize(u" UNwanted , running") + self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) - self.assertListEqual( - tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) def test_full_tokenizer_lower(self): tokenizer = TransfoXLTokenizer(lower_case=True) diff --git a/pytorch_transformers/tests/tokenization_utils_test.py b/pytorch_transformers/tests/tokenization_utils_test.py index e8856d50c2c..26ec2d7a394 100644 --- a/pytorch_transformers/tests/tokenization_utils_test.py +++ b/pytorch_transformers/tests/tokenization_utils_test.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function import unittest +import six from pytorch_transformers import PreTrainedTokenizer from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer @@ -27,8 +28,17 @@ class TokenizerUtilsTest(unittest.TestCase): for model_name in s3_models[:1]: tokenizer = tokenizer_class.from_pretrained(model_name) self.assertIsNotNone(tokenizer) + self.assertIsInstance(tokenizer, tokenizer_class) self.assertIsInstance(tokenizer, PreTrainedTokenizer) + for special_tok in tokenizer.all_special_tokens: + if six.PY2: + self.assertIsInstance(special_tok, unicode) + else: + self.assertIsInstance(special_tok, str) + special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) + self.assertIsInstance(special_tok_id, int) + def test_pretrained_tokenizers(self): self.check_tokenizer_from_pretrained(GPT2Tokenizer) diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index 00d273a628f..b543ed23f87 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -17,10 +17,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest import json -import shutil -import pytest +import tempfile -from pytorch_transformers.tokenization_xlm import XLMTokenizer +from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES from .tokenization_tests_commons import create_and_check_tokenizer_commons @@ -31,31 +30,31 @@ class XLMTokenizationTest(unittest.TestCase): vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "w", "r", "t", "lo", "low", "er", - "low", "lowest", "newer", "wider"] + "low", "lowest", "newer", "wider", ""] vocab_tokens = dict(zip(vocab, range(len(vocab)))) merges = ["l o 123", "lo w 1456", "e r 1789", ""] - with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: - fp.write(json.dumps(vocab_tokens)) - vocab_file = fp.name - with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: - fp.write("\n".join(merges)) - merges_file = fp.name - create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["", ""]) + with tempfile.TemporaryDirectory() as tmpdirname: + vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) + merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file']) + with open(vocab_file, "w") as fp: + fp.write(json.dumps(vocab_tokens)) + with open(merges_file, "w") as fp: + fp.write("\n".join(merges)) - tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["", ""]) - 
os.remove(vocab_file) - os.remove(merges_file) + create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname) - text = "lower" - bpe_tokens = ["low", "er"] - tokens = tokenizer.tokenize(text) - self.assertListEqual(tokens, bpe_tokens) + tokenizer = XLMTokenizer(vocab_file, merges_file) - input_tokens = tokens + [""] - input_bpe_tokens = [14, 15, 20] - self.assertListEqual( - tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + text = "lower" + bpe_tokens = ["low", "er"] + tokens = tokenizer.tokenize(text) + self.assertListEqual(tokens, bpe_tokens) + + input_tokens = tokens + [""] + input_bpe_tokens = [14, 15, 20] + self.assertListEqual( + tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) if __name__ == '__main__': diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index 6e81f214b76..8fc98209ba4 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -16,10 +16,9 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest -import shutil -import pytest +import tempfile -from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) +from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES) from.tokenization_tests_commons import create_and_check_tokenizer_commons @@ -29,34 +28,37 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), class XLNetTokenizationTest(unittest.TestCase): def test_full_tokenizer(self): - create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB) - tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) - tokens = tokenizer.tokenize(u'This is a test') - self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) + with tempfile.TemporaryDirectory() as tmpdirname: + tokenizer.save_pretrained(tmpdirname) - self.assertListEqual( - tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) + create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname) - tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") - self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', - u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', - u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', - SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) - ids = tokenizer.convert_tokens_to_ids(tokens) - self.assertListEqual( - ids, [8, 21, 84, 55, 24, 19, 7, 0, - 602, 347, 347, 347, 3, 12, 66, - 46, 72, 80, 6, 0, 4]) + tokens = tokenizer.tokenize(u'This is a test') + self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) - back_tokens = tokenizer.convert_ids_to_tokens(ids) - self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', - u'or', u'n', SPIECE_UNDERLINE + u'in', - SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', - SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', - SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', - u'', u'.']) + self.assertListEqual( + tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) + + tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") + self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + 
u'b', + u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', + u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', + SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) + ids = tokenizer.convert_tokens_to_ids(tokens) + self.assertListEqual( + ids, [8, 21, 84, 55, 24, 19, 7, 0, + 602, 347, 347, 347, 3, 12, 66, + 46, 72, 80, 6, 0, 4]) + + back_tokens = tokenizer.convert_ids_to_tokens(ids) + self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', + u'or', u'n', SPIECE_UNDERLINE + u'in', + SPIECE_UNDERLINE + u'', u'', u'2', u'0', u'0', u'0', u',', + SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', + SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', + u'', u'.']) def test_tokenizer_lower(self): tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index b26e5066e93..3e14673f464 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -22,7 +22,6 @@ import os import unicodedata from io import open -from .file_utils import cached_path from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization logger = logging.getLogger(__name__) @@ -32,20 +31,21 @@ VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} PRETRAINED_VOCAB_FILES_MAP = { 'vocab_file': { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", - 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", - 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", - 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", - 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", - 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", - 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", -}} + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", + 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", + } +} PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'bert-base-uncased': 512, @@ -93,8 +93,9 @@ class BertTokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, - never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, + unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", + mask_token="[MASK]", **kwargs): """Constructs a BertTokenizer. Args: @@ -102,17 +103,18 @@ class BertTokenizer(PreTrainedTokenizer): do_lower_case: Whether to lower case the input Only has an effect when do_wordpiece_only=False do_basic_tokenize: Whether to do basic tokenization before wordpiece. - max_len: An artificial maximum length to truncate tokenized sequences to; - Effective maximum length is always the minimum of this - value (if specified) and the underlying BERT model's - sequence length. never_split: List of tokens which will never be split during tokenization. Only has an effect when do_wordpiece_only=False """ + super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, + pad_token=pad_token, cls_token=cls_token, + mask_token=mask_token, **kwargs) if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + if never_split is None: + never_split = self.all_special_tokens self.vocab = load_vocab(vocab_file) self.ids_to_tokens = collections.OrderedDict( [(ids, tok) for tok, ids in self.vocab.items()]) @@ -120,90 +122,34 @@ class BertTokenizer(PreTrainedTokenizer): if do_basic_tokenize: self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=never_split) - self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) - self.max_len = max_len if max_len is not None else int(1e12) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) @property - def UNK_TOKEN(self): - return "[UNK]" + def vocab_size(self): + return len(self.vocab) - @property - def SEP_TOKEN(self): - return "[SEP]" - - @property - def PAD_TOKEN(self): - return "[PAD]" - - @property - def CLS_TOKEN(self): - return "[CLS]" - - @property - def MASK_TOKEN(self): - return "[MASK]" - - @property - def UNK_ID(self): - return self.vocab["[UNK]"] - - @property - def SEP_ID(self): - return self.vocab["[SEP]"] - - @property - def PAD_ID(self): - return self.vocab["[PAD]"] - - @property - def CLS_ID(self): - return self.vocab["[CLS]"] - - @property - def MASK_ID(self): - return self.vocab["[MASK]"] - - def tokenize(self, text): + def _tokenize(self, text): split_tokens = [] if self.do_basic_tokenize: - for token in self.basic_tokenizer.tokenize(text): + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): for sub_token in self.wordpiece_tokenizer.tokenize(token): split_tokens.append(sub_token) else: split_tokens = self.wordpiece_tokenizer.tokenize(text) return split_tokens - def convert_tokens_to_ids(self, tokens): - """Converts a sequence of tokens into ids using the vocab.""" - ids = [] - for token in tokens: - ids.append(self.vocab[token]) - if len(ids) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this BERT model ({} > {}). Running this" - " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) - ) - return ids + def _convert_token_to_id(self, token): + """ Converts a token (str/unicode) in an id using the vocab. 
""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) - def convert_ids_to_tokens(self, ids): - """Converts a sequence of ids in wordpiece tokens using the vocab.""" - tokens = [] - for i in ids: - tokens.append(self.ids_to_tokens[i]) - return tokens + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (string/unicode) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) - def encode(self, text): - return self.convert_tokens_to_ids(self.tokenize(text)) - - def decode(self, token_ids, clean_up_tokenization_spaces=True): + def _convert_ids_to_string(self, tokens_ids): """Converts a sequence of ids in a string.""" - tokens = self.convert_ids_to_tokens(token_ids) + tokens = self.convert_ids_to_tokens(tokens_ids) out_string = ''.join(tokens).replace(' ##', '').strip() - if clean_up_tokenization_spaces: - for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN): - out_string = out_string.replace(special_tok, '') - out_string = clean_up_tokenization(out_string) return out_string def save_vocabulary(self, vocab_path): @@ -245,17 +191,20 @@ class BasicTokenizer(object): def __init__(self, do_lower_case=True, - never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + never_split=None): """Constructs a BasicTokenizer. Args: do_lower_case: Whether to lower case the input. """ + if never_split is None: + never_split = [] self.do_lower_case = do_lower_case self.never_split = never_split - def tokenize(self, text): + def tokenize(self, text, never_split=None): """Tokenizes a piece of text.""" + never_split = self.never_split + (never_split if never_split is not None else []) text = self._clean_text(text) # This was added on November 1st, 2018 for the multilingual and Chinese # models. 
This is also applied to the English models now, but it doesn't @@ -267,7 +216,7 @@ class BasicTokenizer(object): orig_tokens = whitespace_tokenize(text) split_tokens = [] for token in orig_tokens: - if self.do_lower_case and token not in self.never_split: + if self.do_lower_case and token not in never_split: token = token.lower() token = self._run_strip_accents(token) split_tokens.extend(self._run_split_on_punc(token)) @@ -286,9 +235,9 @@ class BasicTokenizer(object): output.append(char) return "".join(output) - def _run_split_on_punc(self, text): + def _run_split_on_punc(self, text, never_split=None): """Splits punctuation on a piece of text.""" - if text in self.never_split: + if never_split is not None and text in never_split: return [text] chars = list(text) i = 0 @@ -360,7 +309,7 @@ class BasicTokenizer(object): class WordpieceTokenizer(object): """Runs WordPiece tokenization.""" - def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): self.vocab = vocab self.unk_token = unk_token self.max_input_chars_per_word = max_input_chars_per_word diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index abdfe39c1cd..af1ad2cf8f6 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -38,7 +38,6 @@ logger = logging.getLogger(__name__) VOCAB_FILES_NAMES = { 'vocab_file': 'vocab.json', 'merges_file': 'merges.txt', - 'special_tokens_file': 'special_tokens.txt' } PRETRAINED_VOCAB_FILES_MAP = { @@ -52,11 +51,6 @@ PRETRAINED_VOCAB_FILES_MAP = { 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", }, - 'special_tokens_file': - { - 'gpt2': None, - 'gpt2-medium': None, - } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -108,8 +102,10 @@ class GPT2Tokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None): - self.max_len = max_len if max_len is not None else int(1e12) + def __init__(self, vocab_file, merges_file, errors='replace', + bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): + super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs) + self.encoder = json.load(open(vocab_file)) self.decoder = {v:k for k,v in self.encoder.items()} self.errors = errors # how to handle errors in decoding @@ -123,32 +119,9 @@ class GPT2Tokenizer(PreTrainedTokenizer): # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") - all_special_tokens = [] - if special_tokens_file is not None: - special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - all_special_tokens.extend(special_tokens_to_add) - if special_tokens is not None and special_tokens: - all_special_tokens.extend(special_tokens) - - self.special_tokens = {} - self.special_tokens_decoder = {} - self.set_special_tokens(all_special_tokens) - - def __len__(self): - return len(self.encoder) + len(self.special_tokens) - - def set_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. 
- The additional tokens are indexed starting from the last index of the - current vocabulary in the order of the `special_tokens` list. - """ - if not special_tokens: - self.special_tokens = {} - self.special_tokens_decoder = {} - return - self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) - self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} - logger.info("Special tokens {}".format(self.special_tokens)) + @property + def vocab_size(self): + return len(self.encoder) def bpe(self, token): if token in self.cache: @@ -191,7 +164,7 @@ class GPT2Tokenizer(PreTrainedTokenizer): self.cache[token] = word return word - def tokenize(self, text): + def _tokenize(self, text): """ Tokenize a string. """ bpe_tokens = [] for token in re.findall(self.pat, text): @@ -202,57 +175,27 @@ class GPT2Tokenizer(PreTrainedTokenizer): bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) return bpe_tokens - def convert_tokens_to_ids(self, tokens): - """ Converts a sequence of tokens into ids using the vocab. """ - ids = [] - if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): - if tokens in self.special_tokens: - return self.special_tokens[tokens] - else: - return self.encoder.get(tokens, 0) - for token in tokens: - if token in self.special_tokens: - ids.append(self.special_tokens[token]) - else: - ids.append(self.encoder.get(token, 0)) - if len(ids) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this OpenAI GPT model ({} > {}). Running this" - " sequence through the model will result in indexing errors".format(len(ids), self.max_len) - ) - return ids + def _convert_token_to_id(self, token): + """ Converts a token (str/unicode) in an id using the vocab. 
""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): - """Converts a sequence of ids in BPE tokens using the vocab.""" - tokens = [] - for i in ids: - if i in self.special_tokens_decoder: - if not skip_special_tokens: - tokens.append(self.special_tokens_decoder[i]) - else: - tokens.append(self.decoder[i]) - return tokens + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (string/unicode) using the vocab.""" + return self.decoder.get(index, self.unk_token) - def encode(self, text): - return self.convert_tokens_to_ids(self.tokenize(text)) - - def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True): - text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens)) + def _convert_ids_to_string(self, tokens_ids): + """Converts a sequence of ids in a string.""" + text = ''.join(tokens_ids) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) - if clean_up_tokenization_spaces: - text = text.replace('', '') - text = clean_up_tokenization(text) return text - def save_vocabulary(self, vocab_path): + def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" - if not os.path.isdir(vocab_path): - logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) return - vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) - merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) - special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file']) + vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) with open(vocab_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) @@ -268,14 +211,4 @@ class GPT2Tokenizer(PreTrainedTokenizer): writer.write(' '.join(bpe_tokens) + u'\n') index += 1 - index = len(self.encoder) - with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 
- " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) - index = token_index - writer.write(token + u'\n') - index += 1 - - return vocab_file, merge_file, special_tokens_file + return vocab_file, merge_file diff --git a/pytorch_transformers/tokenization_openai.py b/pytorch_transformers/tokenization_openai.py index 419dfdad921..16d355c57d9 100644 --- a/pytorch_transformers/tokenization_openai.py +++ b/pytorch_transformers/tokenization_openai.py @@ -20,13 +20,9 @@ import json import logging import os import re -import sys from io import open -from tqdm import tqdm - -from .file_utils import cached_path -from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization +from .tokenization_utils import PreTrainedTokenizer from .tokenization_bert import BasicTokenizer logger = logging.getLogger(__name__) @@ -34,7 +30,6 @@ logger = logging.getLogger(__name__) VOCAB_FILES_NAMES = { 'vocab_file': 'vocab.json', 'merges_file': 'merges.txt', - 'special_tokens_file': 'special_tokens.txt' } PRETRAINED_VOCAB_FILES_MAP = { @@ -46,10 +41,6 @@ PRETRAINED_VOCAB_FILES_MAP = { { 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", }, - 'special_tokens_file': - { - 'openai-gpt': None, - } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { @@ -88,14 +79,14 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): BPE tokenizer. Peculiarities: - lower case all inputs - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. - - argument special_tokens and function set_special_tokens: - can be used to add additional symbols (ex: "__classify__") to a vocabulary. """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None): + def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): + super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) + try: import ftfy import spacy @@ -103,11 +94,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): self.fix_text = ftfy.fix_text except ImportError: logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") - self.nlp = BasicTokenizer(do_lower_case=True, - never_split=special_tokens if special_tokens is not None else []) + self.nlp = BasicTokenizer(do_lower_case=True) self.fix_text = None - self.max_len = max_len if max_len is not None else int(1e12) self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v:k for k,v in self.encoder.items()} merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] @@ -115,35 +104,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} - all_special_tokens = [] - if special_tokens_file is not None: - special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - all_special_tokens.extend(special_tokens_to_add) - if special_tokens is not None and special_tokens: - all_special_tokens.extend(special_tokens) - - self.special_tokens = {} - self.special_tokens_decoder = {} - self.set_special_tokens(all_special_tokens) - - def __len__(self): - return len(self.encoder) + len(self.special_tokens) - - def set_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. 
- The additional tokens are indexed starting from the last index of the - current vocabulary in the order of the `special_tokens` list. - """ - if not special_tokens: - self.special_tokens = {} - self.special_tokens_decoder = {} - return - self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) - self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} - if self.fix_text is None: - # Using BERT's BasicTokenizer: we can update the tokenizer - self.nlp.never_split = special_tokens - logger.info("Special tokens {}".format(self.special_tokens)) + @property + def vocab_size(self): + return len(self.encoder) def bpe(self, token): word = tuple(token[:-1]) + (token[-1] + '',) @@ -188,7 +151,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): self.cache[token] = word return word - def tokenize(self, text): + def _tokenize(self, text): """ Tokenize a string. """ split_tokens = [] if self.fix_text is None: @@ -203,58 +166,26 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) return split_tokens - def convert_tokens_to_ids(self, tokens): - """ Converts a sequence of tokens into ids using the vocab. """ - ids = [] - if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): - if tokens in self.special_tokens: - return self.special_tokens[tokens] - else: - return self.encoder.get(tokens, 0) - for token in tokens: - if token in self.special_tokens: - ids.append(self.special_tokens[token]) - else: - ids.append(self.encoder.get(token, 0)) - if len(ids) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this OpenAI GPT model ({} > {}). Running this" - " sequence through the model will result in indexing errors".format(len(ids), self.max_len) - ) - return ids + def _convert_token_to_id(self, token): + """ Converts a token (str/unicode) in an id using the vocab. 
""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): - """Converts a sequence of ids in BPE tokens using the vocab.""" - tokens = [] - for i in ids: - if i in self.special_tokens_decoder: - if not skip_special_tokens: - tokens.append(self.special_tokens_decoder[i]) - else: - tokens.append(self.decoder[i]) - return tokens + def _convert_id_to_token(self, index): + """Converts an id in a token (BPE) using the vocab.""" + return self.decoder.get(index, self.unk_token) - def encode(self, text): - return self.convert_tokens_to_ids(self.tokenize(text)) - - def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + def _convert_ids_to_string(self, tokens_ids): """Converts a sequence of ids in a string.""" - tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) - out_string = ''.join(tokens).replace('', ' ').strip() - if clean_up_tokenization_spaces: - out_string = out_string.replace('', '') - out_string = clean_up_tokenization(out_string) + out_string = ''.join(tokens_ids).replace('', ' ').strip() return out_string - def save_vocabulary(self, vocab_path): + def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" - if not os.path.isdir(vocab_path): - logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) return - vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) - merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) - special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file']) + vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) with open(vocab_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) @@ -270,14 +201,4 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): writer.write(' '.join(bpe_tokens) + u'\n') index += 1 - index = len(self.encoder) - with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 
- " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) - index = token_index - writer.write(token + u'\n') - index += 1 - - return vocab_file, merge_file, special_tokens_file + return vocab_file, merge_file diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index a86c8fe4600..0b4e8c0ca5e 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -41,7 +41,7 @@ else: logger = logging.getLogger(__name__) -VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'} +VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'} PRETRAINED_VOCAB_FILES_MAP = { 'pretrained_vocab_file': @@ -67,9 +67,17 @@ class TransfoXLTokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, + def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False, delimiter=None, vocab_file=None, pretrained_vocab_file=None, - never_split=("", "", "")): + never_split=None, unk_token="", eos_token="", + additional_special_tokens=[""], **kwargs): + super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token, + additional_special_tokens=additional_special_tokens, + **kwargs) + if never_split is None: + never_split = self.all_special_tokens + if special is None: + special = [] self.counter = Counter() self.special = special self.min_freq = min_freq @@ -200,11 +208,13 @@ class TransfoXLTokenizer(PreTrainedTokenizer): self.idx2sym.append(sym) self.sym2idx[sym] = len(self.idx2sym) - 1 - def get_sym(self, idx): + def _convert_id_to_token(self, idx): + """Converts an id in a token (BPE) using the vocab.""" assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx) return self.idx2sym[idx] - def get_idx(self, sym): + def _convert_token_to_id(self, sym): + """ Converts a token (str/unicode) in an id using the vocab. 
""" if sym in self.sym2idx: return self.sym2idx[sym] else: @@ -220,36 +230,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer): else: raise ValueError('Token not in vocabulary and no token in vocabulary for replacement') - def convert_ids_to_tokens(self, indices): - """Converts a sequence of indices in symbols using the vocab.""" - return [self.get_sym(idx) for idx in indices] - - def convert_tokens_to_ids(self, symbols): - """Converts a sequence of symbols into ids using the vocab.""" - return [self.get_idx(sym) for sym in symbols] + def _convert_ids_to_string(self, tokens_ids): + """Converts a sequence of ids in a string.""" + out_string = ' '.join(tokens_ids).strip() + return out_string def convert_to_tensor(self, symbols): return torch.LongTensor(self.convert_tokens_to_ids(symbols)) - def encode(self, text): - return self.convert_tokens_to_ids(self.tokenize(text)) - - def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True): - """Converts a sequence of indices in a string.""" - if exclude is None: - out_string = ' '.join([self.get_sym(idx) for idx in indices]) - else: - out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) - - if clean_up_tokenization_spaces: - out_string = clean_up_tokenization(out_string) - - return out_string - - def __len__(self): + @property + def vocab_size(self): return len(self.idx2sym) - def tokenize(self, line, add_eos=False, add_double_eos=False): + def _tokenize(self, line, add_eos=False, add_double_eos=False): line = line.strip() # convert to lower case if self.lower_case: diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 90043156572..b191dd22e6e 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -16,37 +16,145 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -import sys -import json import logging import os -import regex as re +import json +import six from io import open -try: - from functools import lru_cache -except ImportError: - # Just a dummy decorator to get the checks to run on python2 - # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. - def lru_cache(): - return lambda func: func - from .file_utils import cached_path logger = logging.getLogger(__name__) +SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' +ADDED_TOKENS_FILE = 'added_tokens.json' class PreTrainedTokenizer(object): - """ An abstract class to handle dowloading and loading pretrained tokenizers. + """ An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary. + + Derived class can set up a few special tokens to be used in common scripts and internals: + bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token + additional_special_tokens = [] + + We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the + specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...). 
""" vocab_files_names = {} pretrained_vocab_files_map = {} max_model_input_sizes = {} + SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", + "pad_token", "cls_token", "mask_token", + "additional_special_tokens"] + + @property + def bos_token(self): + if self._bos_token is None: + logger.error("Using bos_token, but it is not set yet.") + return self._bos_token + + @property + def eos_token(self): + if self._eos_token is None: + logger.error("Using eos_token, but it is not set yet.") + return self._eos_token + + @property + def unk_token(self): + if self._unk_token is None: + logger.error("Using unk_token, but it is not set yet.") + return self._unk_token + + @property + def sep_token(self): + if self._sep_token is None: + logger.error("Using sep_token, but it is not set yet.") + return self._sep_token + + @property + def pad_token(self): + if self._pad_token is None: + logger.error("Using pad_token, but it is not set yet.") + return self._pad_token + + @property + def cls_token(self): + if self._cls_token is None: + logger.error("Using cls_token, but it is not set yet.") + return self._cls_token + + @property + def mask_token(self): + if self._mask_token is None: + logger.error("Using mask_token, but it is not set yet.") + return self._mask_token + + @property + def additional_special_tokens(self): + if self._additional_special_tokens is None: + logger.error("Using additional_special_tokens, but it is not set yet.") + return self._additional_special_tokens + + @bos_token.setter + def bos_token(self, value): + self._bos_token = value + + @eos_token.setter + def eos_token(self, value): + self._eos_token = value + + @unk_token.setter + def unk_token(self, value): + self._unk_token = value + + @sep_token.setter + def sep_token(self, value): + self._sep_token = value + + @pad_token.setter + def pad_token(self, value): + self._pad_token = value + + @cls_token.setter + def cls_token(self, value): + self._cls_token = value + + @mask_token.setter + def mask_token(self, value): + self._mask_token = value + + @additional_special_tokens.setter + def additional_special_tokens(self, value): + self._additional_special_tokens = value + + def __init__(self, max_len=None, **kwargs): + self._bos_token = None + self._eos_token = None + self._unk_token = None + self._sep_token = None + self._pad_token = None + self._cls_token = None + self._mask_token = None + self._additional_special_tokens = [] + + self.max_len = max_len if max_len is not None else int(1e12) + self.added_tokens_encoder = {} + self.added_tokens_decoder = {} + + for key, value in kwargs.items(): + if key not in self.SPECIAL_TOKENS_ATTRIBUTES: + raise ValueError( + "PreTrainedTokenizer.__init__() argument {} should be in {}".format( + key, ', '.join(self.SPECIAL_TOKENS_ATTRIBUTES))) + else: + setattr(self, key, value) + + @classmethod def from_pretrained(cls, *inputs, **kwargs): return cls._from_pretrained(*inputs, **kwargs) + @classmethod def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): """ @@ -59,16 +167,20 @@ class PreTrainedTokenizer(object): for file_id, map_list in cls.pretrained_vocab_files_map.items(): vocab_files[file_id] = map_list[pretrained_model_name_or_path] else: - for file_id, file_name in cls.vocab_files_names.items(): + all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, + 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE} + all_vocab_files_names.update(cls.vocab_files_names) + for file_id, file_name in all_vocab_files_names.items(): if 
os.path.isdir(pretrained_model_name_or_path): full_file_name = os.path.join(pretrained_model_name_or_path, file_name) else: full_file_name = pretrained_model_name_or_path if not os.path.exists(full_file_name): - logger.info("Didn't find file {}. We don't load it.".format(full_file_name)) + logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) full_file_name = None vocab_files[file_id] = full_file_name - # redirect to the cache, if necessary + + # Get files from url, cache, or disk depending on the case try: resolved_vocab_files = {} for file_id, file_path in vocab_files.items(): @@ -95,6 +207,7 @@ class PreTrainedTokenizer(object): logger.info("loading file {} from cache at {}".format( file_path, resolved_vocab_files[file_id])) + # Set max length if needed if pretrained_model_name_or_path in cls.max_model_input_sizes: # if we're using a pretrained model, ensure the tokenizer # wont index sequences longer than the number of positional embeddings @@ -102,31 +215,255 @@ class PreTrainedTokenizer(object): kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) # Merge resolved_vocab_files arguments in kwargs. + added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) + special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) for args_name, file_path in resolved_vocab_files.items(): - kwargs[args_name] = file_path + if args_name not in kwargs: + kwargs[args_name] = file_path + if special_tokens_map_file is not None: + special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) + for key, value in special_tokens_map.items(): + if key not in kwargs: + kwargs[key] = value # Instantiate tokenizer. tokenizer = cls(*inputs, **kwargs) + # Add supplementary tokens. + if added_tokens_file is not None: + added_tokens = json.load(open(added_tokens_file, encoding="utf-8")) + added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens)) + added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} + tokenizer.added_tokens_encoder.update(added_tok_encoder) + tokenizer.added_tokens_decoder.update(added_tok_decoder) + return tokenizer - def tokenize(self, text): + + def save_pretrained(self, save_directory): + """ Save the tokenizer vocabulary files (with added tokens) and the + special-tokens-to-class-attributes-mapping to a directory, so that it + can be re-loaded using the `from_pretrained(save_directory)` class method. + """ + if not os.path.isdir(save_directory): + logger.error("Saving directory ({}) should be a directory".format(save_directory)) + return + + special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) + added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) + + with open(special_tokens_map_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) + + with open(added_tokens_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False)) + + vocab_files = self.save_vocabulary(save_directory) + + return vocab_files + (special_tokens_map_file, added_tokens_file) + + + def save_vocabulary(self, save_directory): + """ Save the tokenizer vocabulary to a directory. This method doesn't save added tokens + and special token mappings. + + Please use `save_pretrained()` to save the full Tokenizer state so that it can be + reloaded using the `from_pretrained(save_directory)` class method. 
+ """ + raise NotImplementedError + + + def vocab_size(self): + raise NotImplementedError + + + def __len__(self): + return self.vocab_size + len(self.added_tokens_encoder) + + + def add_tokens(self, new_tokens): + """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the + vocabulary, they are added to the added_tokens_encoder with indices starting from + the last index of the current vocabulary. + + Returns: + Number of tokens added to the vocabulary which can be used to correspondingly + increase the size of the associated model embedding matrices. + """ + if not new_tokens: + return 0 + + to_add_tokens = [] + for token in new_tokens: + if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): + to_add_tokens.append(token) + logger.info("Adding %s to the vocabulary", token) + + added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens)) + added_tok_decoder = {v:k for k, v in added_tok_encoder.items()} + self.added_tokens_encoder.update(added_tok_encoder) + self.added_tokens_decoder.update(added_tok_decoder) + + return len(to_add_tokens) + + + def add_special_tokens(self, special_tokens_dict): + """ Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them + to class attributes. If the special tokens are not in the vocabulary, they are added + to it and indexed starting from the last index of the current vocabulary. + + Returns: + Number of tokens added to the vocabulary which can be used to correspondingly + increase the size of the associated model embedding matrices. + """ + if not special_tokens_dict: + return 0 + + added_special_tokens = self.add_tokens(special_tokens_dict.values()) + for key, value in special_tokens_dict.items(): + logger.info("Assigning %s to the %s key of the tokenizer", value, key) + setattr(self, key, value) + + return added_special_tokens + + + def tokenize(self, text, **kwargs): + """ Converts a string in a sequence of tokens (string), using the tokenizer. + Split in words for word-based vocabulary or sub-words for sub-word-based + vocabularies (BPE/SentencePieces/WordPieces). + + Take care of added tokens. + """ + def split_on_tokens(tok_list, text): + if not text: + return [] + if not tok_list: + return self._tokenize(text, **kwargs) + tok = tok_list[0] + split_text = text.split(tok) + return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \ + for sub_text in split_text), [])[:-1] + + added_tokens = list(self.added_tokens_encoder.keys()) + tokenized_text = split_on_tokens(added_tokens, text) + return tokenized_text + + def _tokenize(self, text, **kwargs): + """ Converts a string in a sequence of tokens (string), using the tokenizer. + Split in words for word-based vocabulary or sub-words for sub-word-based + vocabularies (BPE/SentencePieces/WordPieces). + + Don't take care of added tokens. + """ raise NotImplementedError def convert_tokens_to_ids(self, tokens): + """ Converts a single token or a sequence of tokens (str/unicode) in a integer id + (resp.) a sequence of ids, using the vocabulary. + """ + if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): + return self.convert_token_to_id_with_added_voc(tokens) + + ids = [] + for token in tokens: + ids.append(self.convert_token_to_id_with_added_voc(token)) + if len(ids) > self.max_len: + logger.warning("Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). 
Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.max_len)) + return ids + + + def convert_token_to_id_with_added_voc(self, token): + if token in self.added_tokens_encoder: + return self.added_tokens_encoder[token] + return self._convert_token_to_id(token) + + + def _convert_token_to_id(self, token): raise NotImplementedError - def convert_ids_to_tokens(self, ids): + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """ Converts a single index or a sequence of indices (integers) in a token " + (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens. + + Args: + skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False + """ + if isinstance(ids, int): + return self.convert_id_to_token(ids) + tokens = [] + for index in ids: + if index in self.all_special_ids and skip_special_tokens: + continue + if index in self.added_tokens_decoder: + tokens.append(self.added_tokens_decoder[index]) + else: + tokens.append(self._convert_id_to_token(index)) + return tokens + + + def _convert_id_to_token(self, index): raise NotImplementedError + def encode(self, text): - raise NotImplementedError + """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. + same as self.convert_tokens_to_ids(self.tokenize(text)). + """ + return self.convert_tokens_to_ids(self.tokenize(text)) - def decode(self, token_ids, *input, **kwargs): - raise NotImplementedError - def save_vocabulary(self, vocab_path): - raise NotImplementedError + def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary + with options to remove special tokens and clean up tokenization spaces. + """ + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + text = self._convert_ids_to_string(filtered_tokens) + if clean_up_tokenization_spaces: + text = clean_up_tokenization(text) + return text + + def _convert_ids_to_string(self, tokens_ids): + """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary. + roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)). + """ + return ' '.join(self.convert_ids_to_tokens(tokens_ids)) + + @property + def special_tokens_map(self): + """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their + values ('', ''...) + """ + set_attr = {} + for attr in self.SPECIAL_TOKENS_ATTRIBUTES: + attr_value = getattr(self, "_" + attr) + if attr_value: + set_attr[attr] = attr_value + return set_attr + + @property + def all_special_tokens(self): + """ List all the special tokens ('', ''...) mapped to class attributes + (cls_token, unk_token...). + """ + all_toks = [] + set_attr = self.special_tokens_map + for attr_value in set_attr.values(): + all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value]) + all_toks = list(set(all_toks)) + return all_toks + + @property + def all_special_ids(self): + """ List the vocabulary indices of the special tokens ('', ''...) mapped to + class attributes (cls_token, unk_token...). 
+ """ + all_toks = self.all_special_tokens + all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks) + return all_ids + def clean_up_tokenization(out_string): diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index e37f3888a30..8a11a84f8c6 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -34,7 +34,6 @@ logger = logging.getLogger(__name__) VOCAB_FILES_NAMES = { 'vocab_file': 'vocab.json', 'merges_file': 'merges.txt', - 'special_tokens_file': 'special_tokens.txt' } PRETRAINED_VOCAB_FILES_MAP = { @@ -46,24 +45,12 @@ PRETRAINED_VOCAB_FILES_MAP = { { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", }, - 'special_tokens_file': - { - 'xlm-mlm-en-2048': None, - } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'xlm-mlm-en-2048': 512, } -INDEX = { - "bos_index": 0, - "eos_index": 1, - "pad_index": 2, - "unk_index": 3, - "mask_index": 5 -} - def get_pairs(word): """ Return set of symbol pairs in a word. @@ -103,7 +90,16 @@ class XLMTokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None): + def __init__(self, vocab_file, merges_file, unk_token="", bos_token="", + sep_token="", pad_token="", cls_token="", + mask_token="", additional_special_tokens=["", + "", "", "", "", "", + "", "", "", ""], **kwargs): + super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token, + sep_token=sep_token, pad_token=pad_token, + cls_token=cls_token, mask_token=mask_token, + additional_special_tokens=additional_special_tokens, + **kwargs) try: import ftfy import spacy @@ -111,11 +107,9 @@ class XLMTokenizer(PreTrainedTokenizer): self.fix_text = ftfy.fix_text except ImportError: logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") - self.nlp = BasicTokenizer(do_lower_case=True, - never_split=special_tokens if special_tokens is not None else []) + self.nlp = BasicTokenizer(do_lower_case=True) self.fix_text = None - self.max_len = max_len if max_len is not None else int(1e12) self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v:k for k,v in self.encoder.items()} merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1] @@ -123,35 +117,9 @@ class XLMTokenizer(PreTrainedTokenizer): self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} - all_special_tokens = [] - if special_tokens_file is not None: - special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - all_special_tokens.extend(special_tokens_to_add) - if special_tokens is not None and special_tokens: - all_special_tokens.extend(special_tokens) - - self.special_tokens = {} - self.special_tokens_decoder = {} - self.set_special_tokens(all_special_tokens) - - def __len__(self): - return len(self.encoder) + len(self.special_tokens) - - def set_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. - The additional tokens are indexed starting from the last index of the - current vocabulary in the order of the `special_tokens` list. 
- """ - if not special_tokens: - self.special_tokens = {} - self.special_tokens_decoder = {} - return - self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) - self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} - if self.fix_text is None: - # Using BERT's BasicTokenizer: we can update the tokenizer - self.nlp.never_split = special_tokens - logger.info("Special tokens {}".format(self.special_tokens)) + @property + def vocab_size(self): + return len(self.encoder) def bpe(self, token): word = tuple(token[:-1]) + (token[-1] + '',) @@ -196,7 +164,7 @@ class XLMTokenizer(PreTrainedTokenizer): self.cache[token] = word return word - def tokenize(self, text): + def _tokenize(self, text): """ Tokenize a string. """ split_tokens = [] if self.fix_text is None: @@ -211,58 +179,26 @@ class XLMTokenizer(PreTrainedTokenizer): split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) return split_tokens - def convert_tokens_to_ids(self, tokens): - """ Converts a sequence of tokens into ids using the vocab. """ - ids = [] - if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): - if tokens in self.special_tokens: - return self.special_tokens[tokens] - else: - return self.encoder.get(tokens, 0) - for token in tokens: - if token in self.special_tokens: - ids.append(self.special_tokens[token]) - else: - ids.append(self.encoder.get(token, 0)) - if len(ids) > self.max_len: - logger.warning( - "Token indices sequence length is longer than the specified maximum " - " sequence length for this OpenAI GPT model ({} > {}). Running this" - " sequence through the model will result in indexing errors".format(len(ids), self.max_len) - ) - return ids + def _convert_token_to_id(self, token): + """ Converts a token (str/unicode) in an id using the vocab. 
""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): - """Converts a sequence of ids in BPE tokens using the vocab.""" - tokens = [] - for i in ids: - if i in self.special_tokens_decoder: - if not skip_special_tokens: - tokens.append(self.special_tokens_decoder[i]) - else: - tokens.append(self.decoder[i]) - return tokens + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (string/unicode) using the vocab.""" + return self.decoder.get(index, self.unk_token) - def encode(self, text): - return self.convert_tokens_to_ids(self.tokenize(text)) - - def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + def _convert_ids_to_string(self, tokens_ids): """Converts a sequence of ids in a string.""" - tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) - out_string = ''.join(tokens).replace('', ' ').strip() - if clean_up_tokenization_spaces: - out_string = out_string.replace('', '') - out_string = clean_up_tokenization(out_string) + out_string = ''.join(tokens_ids).replace('', ' ').strip() return out_string - def save_vocabulary(self, vocab_path): + def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" - if not os.path.isdir(vocab_path): - logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + if not os.path.isdir(save_directory): + logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) return - vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) - merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) - special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file']) + vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) with open(vocab_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) @@ -277,14 +213,4 @@ class XLMTokenizer(PreTrainedTokenizer): writer.write(' '.join(bpe_tokens) + u'\n') index += 1 - index = len(self.encoder) - with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 
- " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) - index = token_index - writer.write(token + u'\n') - index += 1 - - return vocab_file, merge_file, special_tokens_file + return vocab_file, merge_file diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index a30e6db8da0..942b532ec60 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -16,17 +16,13 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -import json import logging import os -import sys from shutil import copyfile -from io import open import unicodedata import six -from .file_utils import cached_path from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization logger = logging.getLogger(__name__) @@ -44,8 +40,6 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'xlnet-large-cased': 512, } -VOCAB_NAME = 'spiece.model' - SPIECE_UNDERLINE = u'▁' # Segments (not really needed) @@ -60,31 +54,26 @@ class XLNetTokenizer(PreTrainedTokenizer): SentencePiece based tokenizer. Peculiarities: - requires SentencePiece: https://github.com/google/sentencepiece """ - # Tokens - special_symbols = { - "" : 0, - "" : 1, - "" : 2, - "" : 3, - "" : 4, - "" : 5, - "" : 6, - "" : 7, - "" : 8, - } vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, max_len=None, - do_lower_case=False, remove_space=True, keep_accents=False): + do_lower_case=False, remove_space=True, keep_accents=False, + bos_token="", eos_token="", unk_token="", sep_token="", + pad_token="", cls_token="", mask_token="", + additional_special_tokens=["", ""], **kwargs): + super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, + unk_token=unk_token, sep_token=sep_token, + pad_token=pad_token, cls_token=cls_token, + mask_token=mask_token, additional_special_tokens= + additional_special_tokens, **kwargs) try: import sentencepiece as spm except ImportError: logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" "pip install sentencepiece") - self.max_len = max_len if max_len is not None else int(1e12) self.do_lower_case = do_lower_case self.remove_space = remove_space self.keep_accents = keep_accents @@ -94,46 +83,7 @@ class XLNetTokenizer(PreTrainedTokenizer): self.sp_model.Load(vocab_file) @property - def UNK_TOKEN(self): - return "" - - @property - def SEP_TOKEN(self): - return "" - - @property - def PAD_TOKEN(self): - return "" - - @property - def CLS_TOKEN(self): - return "" - - @property - def MASK_TOKEN(self): - return "" - - @property - def UNK_ID(self): - return self.special_symbols[""] - - @property - def SEP_ID(self): - return self.special_symbols[""] - - @property - def PAD_ID(self): - return self.special_symbols[""] - - @property - def CLS_ID(self): - return self.special_symbols[""] - - @property - def MASK_ID(self): - return self.special_symbols[""] - - def __len__(self): + def vocab_size(self): return len(self.sp_model) def __getstate__(self): @@ -169,7 +119,7 @@ class XLNetTokenizer(PreTrainedTokenizer): return outputs - def tokenize(self, text, return_unicode=True, sample=False): + def _tokenize(self, text, return_unicode=True, sample=False): """ Tokenize a string. 
             return_unicode is used only for py2
         """
@@ -208,56 +158,30 @@ class XLNetTokenizer(PreTrainedTokenizer):
         return new_pieces
 
-    def convert_tokens_to_ids(self, tokens, sample=False):
-        """ Converts a sequence of tokens into ids using the vocab. """
-        ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
-            return self.sp_model.PieceToId(tokens)
-        for token in tokens:
-            ids.append(self.sp_model.PieceToId(token))
-        if len(ids) > self.max_len:
-            logger.warning(
-                "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this XLNet model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
-            )
-        return ids
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str/unicode) in an id using the vocab. """
+        return self.sp_model.PieceToId(token)
 
-    def convert_ids_to_tokens(self, ids, return_unicode=True):
-        """Converts a sequence of ids in tokens."""
-        tokens = []
-        for i in ids:
-            tokens.append(self.sp_model.IdToPiece(i))
+    def _convert_id_to_token(self, index, return_unicode=True):
+        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        token = self.sp_model.IdToPiece(index)
+        if six.PY2 and return_unicode and isinstance(token, str):
+            token = token.decode('utf-8')
+        return token
 
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in tokens:
-                if isinstance(piece, str):
-                    piece = piece.decode('utf-8')
-                ret_pieces.append(piece)
-            tokens = ret_pieces
-        return tokens
-
-    def encode(self, text, sample=False):
-        return self.convert_tokens_to_ids(self.tokenize(text, sample=sample))
-
-    def decode(self, ids, clean_up_tokenization_spaces=True):
+    def _convert_ids_to_string(self, tokens_ids):
         """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(ids)
-        out_string = ''.join(tokens)
-        if clean_up_tokenization_spaces:
-            out_string = out_string.strip().replace('', '')
-            out_string = clean_up_tokenization(out_string)
+        out_string = ''.join(tokens_ids)
         return out_string
 
-    def save_vocabulary(self, vocab_path):
+    def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
             to a directory.
         """
-        if not os.path.isdir(vocab_path):
-            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+        if not os.path.isdir(save_directory):
+            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        out_vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
 
         copyfile(self.vocab_file, out_vocab_file)
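
Below is a minimal usage sketch (not part of the patch) of the unified tokenizer API this refactor converges on: from_pretrained, tokenize/encode/decode, add_tokens, add_special_tokens and save_pretrained, all defined in tokenization_utils.py above. It assumes the diff is applied and that the 'bert-base-uncased' vocabulary can be downloaded; the added tokens ('[NEW_TOK1]', '[NEW_TOK2]', '[BOS]') and the './my_tokenizer' directory are illustrative names, not anything defined by the diff.

import os

from pytorch_transformers import BertTokenizer

# Special tokens ([UNK], [SEP], [PAD], [CLS], [MASK]) are now plain constructor
# arguments handled by PreTrainedTokenizer instead of per-class UNK_TOKEN/... properties.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print(tokenizer.cls_token, tokenizer.sep_token)   # '[CLS]' '[SEP]'
print(tokenizer.special_tokens_map)               # attribute name -> token string
print(len(tokenizer))                             # vocab_size + number of added tokens

# tokenize/encode/decode live in the base class; the model-specific work is delegated
# to _tokenize / _convert_token_to_id / _convert_id_to_token / _convert_ids_to_string.
tokens = tokenizer.tokenize("Hello, how are you?")
ids = tokenizer.encode("Hello, how are you?")
text = tokenizer.decode(ids, skip_special_tokens=True)

# New tokens go into added_tokens_encoder, indexed after the original vocabulary.
num_added = tokenizer.add_tokens(['[NEW_TOK1]', '[NEW_TOK2]'])   # illustrative tokens
print(num_added, len(tokenizer))

# add_special_tokens both registers the token and binds it to one of the
# SPECIAL_TOKENS_ATTRIBUTES (bos_token, eos_token, pad_token, ...).
tokenizer.add_special_tokens({'bos_token': '[BOS]'})             # illustrative token

# save_pretrained writes the vocabulary files plus special_tokens_map.json and
# added_tokens.json, and from_pretrained accepts the resulting directory.
save_dir = './my_tokenizer'                                      # illustrative path
os.makedirs(save_dir, exist_ok=True)
tokenizer.save_pretrained(save_dir)
reloaded = BertTokenizer.from_pretrained(save_dir)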