From 36bca545ff1c13eb7af710d38af4270ef6a965ed Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 5 Jul 2019 15:02:59 +0200 Subject: [PATCH] tokenization abstract class - tests for examples --- examples/run_squad.py | 400 ++++++++++++++++++ examples/test_examples.py | 48 +++ pytorch_transformers/__init__.py | 6 +- pytorch_transformers/modeling_bert.py | 2 +- pytorch_transformers/modeling_gpt2.py | 2 +- pytorch_transformers/modeling_openai.py | 2 +- pytorch_transformers/modeling_transfo_xl.py | 2 +- .../{model_utils.py => modeling_utils.py} | 6 - pytorch_transformers/modeling_xlm.py | 2 +- pytorch_transformers/modeling_xlnet.py | 2 +- .../tests/model_utils_test.py | 50 --- .../tests/modeling_bert_test.py | 2 +- .../tests/modeling_gpt2_test.py | 2 +- .../tests/modeling_openai_test.py | 2 +- ...s_commons.py => modeling_tests_commons.py} | 0 .../tests/modeling_transfo_xl_test.py | 2 +- .../tests/modeling_utils_test.py | 9 +- .../tests/modeling_xlm_test.py | 2 +- .../tests/modeling_xlnet_test.py | 2 +- .../tests/tokenization_bert_test.py | 10 +- .../tests/tokenization_gpt2_test.py | 11 +- .../tests/tokenization_openai_test.py | 10 +- .../tests/tokenization_transfo_xl_test.py | 9 +- .../tests/tokenization_utils_test.py | 36 ++ .../tests/tokenization_xlm_test.py | 12 +- .../tests/tokenization_xlnet_test.py | 12 +- pytorch_transformers/tokenization_bert.py | 66 +-- pytorch_transformers/tokenization_gpt2.py | 117 ++--- pytorch_transformers/tokenization_openai.py | 110 ++--- .../tokenization_transfo_xl.py | 78 ++-- pytorch_transformers/tokenization_utils.py | 114 +++++ pytorch_transformers/tokenization_xlm.py | 122 ++---- pytorch_transformers/tokenization_xlnet.py | 131 ++---- 33 files changed, 815 insertions(+), 566 deletions(-) create mode 100644 examples/run_squad.py create mode 100644 examples/test_examples.py rename pytorch_transformers/{model_utils.py => modeling_utils.py} (98%) delete mode 100644 pytorch_transformers/tests/model_utils_test.py rename pytorch_transformers/tests/{model_tests_commons.py => modeling_tests_commons.py} (100%) rename examples/tests/examples_tests.py => pytorch_transformers/tests/modeling_utils_test.py (92%) create mode 100644 pytorch_transformers/tests/tokenization_utils_test.py create mode 100644 pytorch_transformers/tokenization_utils.py diff --git a/examples/run_squad.py b/examples/run_squad.py new file mode 100644 index 00000000000..d6d7279cb8a --- /dev/null +++ b/examples/run_squad.py @@ -0,0 +1,400 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Run BERT on SQuAD.""" + +from __future__ import absolute_import, division, print_function + +import argparse +import logging +import os +import random +import sys +from io import open + +import numpy as np +import torch +from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, + TensorDataset) +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm, trange + +from tensorboardX import SummaryWriter + +from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME +from pytorch_transformers.modeling_bert import BertForQuestionAnswering +from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule +from pytorch_transformers.tokenization_bert import BertTokenizer + +from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions + +if sys.version_info[0] == 2: + import cPickle as pickle +else: + import pickle + +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + + ## Required parameters + parser.add_argument("--bert_model", default=None, type=str, required=True, + help="Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " + "bert-base-multilingual-cased, bert-base-chinese.") + parser.add_argument("--output_dir", default=None, type=str, required=True, + help="The output directory where the model checkpoints and predictions will be written.") + + ## Other parameters + parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json") + parser.add_argument("--predict_file", default=None, type=str, + help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") + parser.add_argument("--max_seq_length", default=384, type=int, + help="The maximum total input sequence length after WordPiece tokenization. Sequences " + "longer than this will be truncated, and sequences shorter than this will be padded.") + parser.add_argument("--doc_stride", default=128, type=int, + help="When splitting up a long document into chunks, how much stride to take between chunks.") + parser.add_argument("--max_query_length", default=64, type=int, + help="The maximum number of tokens for the question. Questions longer than this will " + "be truncated to this length.") + parser.add_argument("--do_train", action='store_true', help="Whether to run training.") + parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.") + parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") + parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--num_train_epochs", default=3.0, type=float, + help="Total number of training epochs to perform.") + parser.add_argument("--warmup_proportion", default=0.1, type=float, + help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% " + "of training.") + parser.add_argument("--n_best_size", default=20, type=int, + help="The total number of n-best predictions to generate in the nbest_predictions.json " + "output file.") + parser.add_argument("--max_answer_length", default=30, type=int, + help="The maximum length of an answer that can be generated. 
This is needed because the start " + "and end predictions are not conditioned on one another.") + parser.add_argument("--verbose_logging", action='store_true', + help="If true, all of the warnings related to data processing will be printed. " + "A number of warnings are expected for a normal SQuAD evaluation.") + parser.add_argument("--no_cuda", + action='store_true', + help="Whether not to use CUDA when available") + parser.add_argument('--seed', + type=int, + default=42, + help="random seed for initialization") + parser.add_argument('--gradient_accumulation_steps', + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--do_lower_case", + action='store_true', + help="Whether to lower case the input text. True for uncased models, False for cased models.") + parser.add_argument("--local_rank", + type=int, + default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--fp16', + action='store_true', + help="Whether to use 16-bit float precision instead of 32-bit") + parser.add_argument('--overwrite_output_dir', + action='store_true', + help="Overwrite the content of the output directory") + parser.add_argument('--loss_scale', + type=float, default=0, + help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" + "0 (default value): dynamic loss scaling.\n" + "Positive power of 2: static loss scaling value.\n") + parser.add_argument('--version_2_with_negative', + action='store_true', + help='If true, the SQuAD examples contain some that do not have an answer.') + parser.add_argument('--null_score_diff_threshold', + type=float, default=0.0, + help="If null_score - best_non_null is greater than the threshold predict null.") + parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") + parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") + args = parser.parse_args() + print(args) + + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + n_gpu = torch.cuda.device_count() + else: + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + n_gpu = 1 + # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.distributed.init_process_group(backend='nccl') + + logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + + logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( + device, n_gpu, bool(args.local_rank != -1), args.fp16)) + + if args.gradient_accumulation_steps < 1: + raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( + args.gradient_accumulation_steps)) + + args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + if not 
args.do_train and not args.do_predict: + raise ValueError("At least one of `do_train` or `do_predict` must be True.") + + if args.do_train: + if not args.train_file: + raise ValueError( + "If `do_train` is True, then `train_file` must be specified.") + if args.do_predict: + if not args.predict_file: + raise ValueError( + "If `do_predict` is True, then `predict_file` must be specified.") + + if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: + raise ValueError("Output directory {} already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) + model = BertForQuestionAnswering.from_pretrained(args.bert_model) + if args.local_rank == 0: + torch.distributed.barrier() + + if args.fp16: + model.half() + model.to(device) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, + device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + elif n_gpu > 1: + model = torch.nn.DataParallel(model) + + if args.do_train: + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + # Prepare data loader + train_examples = read_squad_examples( + input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative) + cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format( + list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length)) + try: + with open(cached_train_features_file, "rb") as reader: + train_features = pickle.load(reader) + except: + train_features = convert_examples_to_features( + examples=train_examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=True) + if args.local_rank == -1 or torch.distributed.get_rank() == 0: + logger.info(" Saving train features into cached file %s", cached_train_features_file) + with open(cached_train_features_file, "wb") as writer: + pickle.dump(train_features, writer) + + all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) + all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) + all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) + train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, + all_start_positions, all_end_positions) + if args.local_rank == -1: + train_sampler = RandomSampler(train_data) + else: + train_sampler = DistributedSampler(train_data) + + train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) + num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + # if args.local_rank != -1: + # num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() + + # Prepare optimizer + param_optimizer = 
list(model.named_parameters()) + + # hack to remove pooler, which is not used + # thus it produces a None grad that breaks apex + param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] + + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + + if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + + optimizer = FusedAdam(optimizer_grouped_parameters, + lr=args.learning_rate, + bias_correction=False, + max_grad_norm=1.0) + if args.loss_scale == 0: + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + else: + optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) + warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, + t_total=num_train_optimization_steps) + else: + optimizer = BertAdam(optimizer_grouped_parameters, + lr=args.learning_rate, + warmup=args.warmup_proportion, + t_total=num_train_optimization_steps) + + global_step = 0 + + logger.info("***** Running training *****") + logger.info(" Num orig examples = %d", len(train_examples)) + logger.info(" Num split examples = %d", len(train_features)) + logger.info(" Batch size = %d", args.train_batch_size) + logger.info(" Num steps = %d", num_train_optimization_steps) + + model.train() + for epoch in trange(int(args.num_train_epochs), desc="Epoch"): + for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])): + if n_gpu == 1: + batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering itself + input_ids, input_mask, segment_ids, start_positions, end_positions = batch + loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions) + if n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu. 
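+                # Normalize the loss so gradients accumulated over several small steps average out to one full-batch update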
+ if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + optimizer.backward(loss) + else: + loss.backward() + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + # modify learning rate with special warm up BERT uses + # if args.fp16 is False, BertAdam is used and handles this automatically + lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + optimizer.step() + optimizer.zero_grad() + global_step += 1 + if args.local_rank in [-1, 0]: + if not args.fp16: + tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step) + tb_writer.add_scalar('loss', loss.item(), global_step) + + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # Save a trained model, configuration and tokenizer + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) + output_config_file = os.path.join(args.output_dir, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(args.output_dir) + + # Load a trained model and vocabulary that you have fine-tuned + model = BertForQuestionAnswering.from_pretrained(args.output_dir) + tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + + # Good practice: save your training arguments together with the trained model + output_args_file = os.path.join(args.output_dir, 'training_args.bin') + torch.save(args, output_args_file) + else: + model = BertForQuestionAnswering.from_pretrained(args.bert_model) + + model.to(device) + + if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + eval_examples = read_squad_examples( + input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative) + eval_features = convert_examples_to_features( + examples=eval_examples, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + doc_stride=args.doc_stride, + max_query_length=args.max_query_length, + is_training=False) + + logger.info("***** Running predictions *****") + logger.info(" Num orig examples = %d", len(eval_examples)) + logger.info(" Num split examples = %d", len(eval_features)) + logger.info(" Batch size = %d", args.predict_batch_size) + + all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) + all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) + eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) + # Run prediction for full data + eval_sampler = SequentialSampler(eval_data) + eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) + + model.eval() + all_results = [] + logger.info("Start evaluating") + for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): + if len(all_results) % 1000 == 0: + logger.info("Processing example: %d" % (len(all_results))) + input_ids = 
input_ids.to(device) + input_mask = input_mask.to(device) + segment_ids = segment_ids.to(device) + with torch.no_grad(): + batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) + for i, example_index in enumerate(example_indices): + start_logits = batch_start_logits[i].detach().cpu().tolist() + end_logits = batch_end_logits[i].detach().cpu().tolist() + eval_feature = eval_features[example_index.item()] + unique_id = int(eval_feature.unique_id) + all_results.append(RawResult(unique_id=unique_id, + start_logits=start_logits, + end_logits=end_logits)) + output_prediction_file = os.path.join(args.output_dir, "predictions.json") + output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json") + output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json") + write_predictions(eval_examples, eval_features, all_results, + args.n_best_size, args.max_answer_length, + args.do_lower_case, output_prediction_file, + output_nbest_file, output_null_log_odds_file, args.verbose_logging, + args.version_2_with_negative, args.null_score_diff_threshold) + + +if __name__ == "__main__": + main() diff --git a/examples/test_examples.py b/examples/test_examples.py new file mode 100644 index 00000000000..fada43dae2c --- /dev/null +++ b/examples/test_examples.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc.. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
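+"""Smoke tests for the example scripts (currently the SQuAD example)."""
+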
+from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import unittest +import argparse + +try: + # python 3.4+ can use builtin unittest.mock instead of mock package + from unittest.mock import patch +except ImportError: + from mock import patch + +import run_bert_squad as rbs + +def get_setup_file(): + parser = argparse.ArgumentParser() + parser.add_argument('-f') + args = parser.parse_args() + return args.f + +class ExamplesTests(unittest.TestCase): + + def test_run_squad(self): + testargs = ["prog", "-f", "/home/test/setup.py"] + with patch.object(sys, 'argv', testargs): + setup = get_setup_file() + assert setup == "/home/test/setup.py" + # rbs.main() + + +if __name__ == "__main__": + unittest.main() diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index cbd007f8722..6dd78dfd025 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -5,6 +5,7 @@ from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE from .tokenization_xlm import XLMTokenizer +from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization) from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, @@ -26,11 +27,10 @@ from .modeling_xlnet import (XLNetConfig, from .modeling_xlm import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering) +from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, + PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) from .optimization import BertAdam from .optimization_openai import OpenAIAdam from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path) - -from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, - PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index d4967b37184..b2a456209d1 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -29,7 +29,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from .file_utils import cached_path -from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer +from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer logger = logging.getLogger(__name__) diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index c16ad2f7634..090763cda18 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter from .file_utils import cached_path -from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, +from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer, SequenceSummary) from .modeling_bert import BertLayerNorm as LayerNorm diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 1a3e7fbbb47..b715b183713 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -31,7 +31,7 @@ from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter 
from .file_utils import cached_path -from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, +from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer, SequenceSummary) from .modeling_bert import BertLayerNorm as LayerNorm diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 9a882bce96b..465577b0028 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -37,7 +37,7 @@ from torch.nn.parameter import Parameter from .modeling_bert import BertLayerNorm as LayerNorm from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .file_utils import cached_path -from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel +from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel logger = logging.getLogger(__name__) diff --git a/pytorch_transformers/model_utils.py b/pytorch_transformers/modeling_utils.py similarity index 98% rename from pytorch_transformers/model_utils.py rename to pytorch_transformers/modeling_utils.py index 051fbdefbce..b72707ce08a 100644 --- a/pytorch_transformers/model_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -598,9 +598,3 @@ def prune_layer(layer, index, dim=None): return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim) else: raise ValueError("Can't prune layer of class {}".format(layer.__class__)) - -def clean_up_tokenization(out_string): - out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' - ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" - ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") - return out_string diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 6decba3ccee..14f8848a42d 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -35,7 +35,7 @@ from torch.nn import functional as F from torch.nn import CrossEntropyLoss, MSELoss from .file_utils import cached_path -from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, +from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer, SequenceSummary, SQuADHead) logger = logging.getLogger(__name__) diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index f5841e0601d..289dcbd9db6 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -32,7 +32,7 @@ from torch.nn import functional as F from torch.nn import CrossEntropyLoss, MSELoss from .file_utils import cached_path -from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, +from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits) diff --git a/pytorch_transformers/tests/model_utils_test.py b/pytorch_transformers/tests/model_utils_test.py deleted file mode 100644 index 120df35f820..00000000000 --- a/pytorch_transformers/tests/model_utils_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# coding=utf-8 -# Copyright 2018 HuggingFace Inc.. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import unittest -import json -import random -import shutil -import pytest - -import torch - -from pytorch_transformers import PretrainedConfig, PreTrainedModel -from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP - - -class ModelUtilsTest(unittest.TestCase): - def test_model_from_pretrained(self): - for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: - config = BertConfig.from_pretrained(model_name) - self.assertIsNotNone(config) - self.assertIsInstance(config, PretrainedConfig) - - model = BertModel.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertIsInstance(model, PreTrainedModel) - - config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) - model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) - self.assertEqual(model.config.output_attentions, True) - self.assertEqual(model.config.output_hidden_states, True) - self.assertEqual(model.config, config) - -if __name__ == "__main__": - unittest.main() diff --git a/pytorch_transformers/tests/modeling_bert_test.py b/pytorch_transformers/tests/modeling_bert_test.py index b140f5e6473..2ba59317be0 100644 --- a/pytorch_transformers/tests/modeling_bert_test.py +++ b/pytorch_transformers/tests/modeling_bert_test.py @@ -26,7 +26,7 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM, BertForTokenClassification, BertForMultipleChoice) from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP -from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) +from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) class BertModelTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/modeling_gpt2_test.py b/pytorch_transformers/tests/modeling_gpt2_test.py index 4ace52571a5..7400c9f64dd 100644 --- a/pytorch_transformers/tests/modeling_gpt2_test.py +++ b/pytorch_transformers/tests/modeling_gpt2_test.py @@ -28,7 +28,7 @@ import torch from pytorch_transformers import (GPT2Config, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) -from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester) +from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester) class GPT2ModelTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/modeling_openai_test.py b/pytorch_transformers/tests/modeling_openai_test.py index fe811570237..27263ecb245 100644 --- a/pytorch_transformers/tests/modeling_openai_test.py +++ b/pytorch_transformers/tests/modeling_openai_test.py @@ -24,7 +24,7 @@ import torch from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) -from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester) +from .modeling_tests_commons import (create_and_check_commons, ConfigTester, 
GPTModelTester) class OpenAIModelTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/model_tests_commons.py b/pytorch_transformers/tests/modeling_tests_commons.py similarity index 100% rename from pytorch_transformers/tests/model_tests_commons.py rename to pytorch_transformers/tests/modeling_tests_commons.py diff --git a/pytorch_transformers/tests/modeling_transfo_xl_test.py b/pytorch_transformers/tests/modeling_transfo_xl_test.py index d15a19eb64f..f2906d879fe 100644 --- a/pytorch_transformers/tests/modeling_transfo_xl_test.py +++ b/pytorch_transformers/tests/modeling_transfo_xl_test.py @@ -28,7 +28,7 @@ import torch from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP -from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor +from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor class TransfoXLModelTest(unittest.TestCase): class TransfoXLModelTester(object): diff --git a/examples/tests/examples_tests.py b/pytorch_transformers/tests/modeling_utils_test.py similarity index 92% rename from examples/tests/examples_tests.py rename to pytorch_transformers/tests/modeling_utils_test.py index 120df35f820..1866d353538 100644 --- a/examples/tests/examples_tests.py +++ b/pytorch_transformers/tests/modeling_utils_test.py @@ -16,17 +16,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import unittest -import json -import random -import shutil -import pytest - -import torch from pytorch_transformers import PretrainedConfig, PreTrainedModel -from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP +from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP class ModelUtilsTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/modeling_xlm_test.py b/pytorch_transformers/tests/modeling_xlm_test.py index 8a8905cc312..9c511f21a86 100644 --- a/pytorch_transformers/tests/modeling_xlm_test.py +++ b/pytorch_transformers/tests/modeling_xlm_test.py @@ -23,7 +23,7 @@ import pytest from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification) from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP -from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) +from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor) class XLMModelTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/modeling_xlnet_test.py b/pytorch_transformers/tests/modeling_xlnet_test.py index b9d55a26c72..b762426d2c7 100644 --- a/pytorch_transformers/tests/modeling_xlnet_test.py +++ b/pytorch_transformers/tests/modeling_xlnet_test.py @@ -28,7 +28,7 @@ import torch from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering) from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP -from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor +from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor class XLNetModelTest(unittest.TestCase): class XLNetModelTester(object): diff --git a/pytorch_transformers/tests/tokenization_bert_test.py 
b/pytorch_transformers/tests/tokenization_bert_test.py index 59a87a4cb99..37e20cc2865 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer, BertTokenizer, WordpieceTokenizer, _is_control, _is_punctuation, - _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP) + _is_whitespace) from .tokenization_tests_commons import create_and_check_tokenizer_commons @@ -49,14 +49,6 @@ class TokenizationTest(unittest.TestCase): os.remove(vocab_file) - @pytest.mark.slow - def test_tokenizer_from_pretrained(self): - cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: - tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) - self.assertIsNotNone(tokenizer) - def test_chinese(self): tokenizer = BasicTokenizer() diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index c6d926bdd40..8b06161b53c 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -17,10 +17,8 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import unittest import json -import shutil -import pytest -from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer from .tokenization_tests_commons import create_and_check_tokenizer_commons @@ -56,13 +54,6 @@ class GPT2TokenizationTest(unittest.TestCase): os.remove(vocab_file) os.remove(merges_file) - # @pytest.mark.slow - def test_tokenizer_from_pretrained(self): - cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: - tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) - self.assertIsNotNone(tokenizer) if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index 38315f927b5..3f8c49f8886 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -20,7 +20,7 @@ import json import shutil import pytest -from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer from.tokenization_tests_commons import create_and_check_tokenizer_commons @@ -58,14 +58,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) - @pytest.mark.slow - def test_tokenizer_from_pretrained(self): - cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: - tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) - self.assertIsNotNone(tokenizer) - if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index f744e319c8c..f583e30b56b 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -20,7 +20,7 @@ 
from io import open import shutil import pytest -from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer from.tokenization_tests_commons import create_and_check_tokenizer_commons @@ -59,13 +59,6 @@ class TransfoXLTokenizationTest(unittest.TestCase): tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]) - @pytest.mark.slow - def test_tokenizer_from_pretrained(self): - cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: - tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) - self.assertIsNotNone(tokenizer) if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_utils_test.py b/pytorch_transformers/tests/tokenization_utils_test.py new file mode 100644 index 00000000000..e8856d50c2c --- /dev/null +++ b/pytorch_transformers/tests/tokenization_utils_test.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc.. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from pytorch_transformers import PreTrainedTokenizer +from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer + +class TokenizerUtilsTest(unittest.TestCase): + def check_tokenizer_from_pretrained(self, tokenizer_class): + s3_models = list(tokenizer_class.max_model_input_sizes.keys()) + for model_name in s3_models[:1]: + tokenizer = tokenizer_class.from_pretrained(model_name) + self.assertIsNotNone(tokenizer) + self.assertIsInstance(tokenizer, PreTrainedTokenizer) + + def test_pretrained_tokenizers(self): + self.check_tokenizer_from_pretrained(GPT2Tokenizer) + +if __name__ == "__main__": + unittest.main() diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index 9cc18f3d605..00d273a628f 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -20,9 +20,9 @@ import json import shutil import pytest -from pytorch_transformers.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from pytorch_transformers.tokenization_xlm import XLMTokenizer -from.tokenization_tests_commons import create_and_check_tokenizer_commons +from .tokenization_tests_commons import create_and_check_tokenizer_commons class XLMTokenizationTest(unittest.TestCase): @@ -57,14 +57,6 @@ class XLMTokenizationTest(unittest.TestCase): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) - @pytest.mark.slow - def test_tokenizer_from_pretrained(self): - cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: - tokenizer = XLMTokenizer.from_pretrained(model_name, 
cache_dir=cache_dir) - shutil.rmtree(cache_dir) - self.assertIsNotNone(tokenizer) - if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index 4dd76e114bb..6e81f214b76 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -19,9 +19,7 @@ import unittest import shutil import pytest -from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, - PRETRAINED_VOCAB_ARCHIVE_MAP, - SPIECE_UNDERLINE) +from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) from.tokenization_tests_commons import create_and_check_tokenizer_commons @@ -60,14 +58,6 @@ class XLNetTokenizationTest(unittest.TestCase): SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'', u'.']) - @pytest.mark.slow - def test_tokenizer_from_pretrained(self): - cache_dir = "/tmp/pytorch_transformers_test/" - for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: - tokenizer = XLNetTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - shutil.rmtree(cache_dir) - self.assertIsNotNone(tokenizer) - def test_tokenizer_lower(self): tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True) tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index c8db62b9c01..b26e5066e93 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -23,11 +23,15 @@ import unicodedata from io import open from .file_utils import cached_path -from .model_utils import clean_up_tokenization +from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization logger = logging.getLogger(__name__) -PRETRAINED_VOCAB_ARCHIVE_MAP = { +VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", @@ -41,8 +45,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = { 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", -} -PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { +}} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'bert-base-uncased': 512, 'bert-large-uncased': 512, 'bert-base-cased': 512, @@ -57,7 +62,6 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 'bert-base-cased-finetuned-mrpc': 512, } -VOCAB_NAME = 'vocab.txt' def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" @@ -83,8 +87,11 @@ def whitespace_tokenize(text): return tokens -class BertTokenizer(object): +class BertTokenizer(PreTrainedTokenizer): """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + vocab_files_names = VOCAB_FILES_NAMES + 
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): @@ -203,7 +210,7 @@ class BertTokenizer(object): """Save the tokenizer vocabulary to a directory or file.""" index = 0 if os.path.isdir(vocab_path): - vocab_file = os.path.join(vocab_path, VOCAB_NAME) + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) with open(vocab_file, "w", encoding="utf-8") as writer: for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): if index != token_index: @@ -215,13 +222,10 @@ class BertTokenizer(object): return (vocab_file,) @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + """ Instantiate a BertTokenizer from pre-trained vocabulary files. """ - Instantiate a PreTrainedBertModel from a pre-trained model file. - Download and cache the pre-trained model file if needed. - """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): logger.warning("The pre-trained model you are loading is a cased model but you have not set " "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " @@ -232,40 +236,8 @@ class BertTokenizer(object): "`do_lower_case` to False. We are setting `do_lower_case=True` for you " "but you may want to check this behavior.") kwargs['do_lower_case'] = True - else: - vocab_file = pretrained_model_name_or_path - if os.path.isdir(vocab_file): - vocab_file = os.path.join(vocab_file, VOCAB_NAME) - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - except EnvironmentError: - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - logger.error( - "Couldn't reach server at '{}' to download vocabulary.".format( - vocab_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - vocab_file)) - return None - if resolved_vocab_file == vocab_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) - # Instantiate tokenizer. 
- tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) - return tokenizer + + return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) class BasicTokenizer(object): diff --git a/pytorch_transformers/tokenization_gpt2.py b/pytorch_transformers/tokenization_gpt2.py index 2947ce66b83..abdfe39c1cd 100644 --- a/pytorch_transformers/tokenization_gpt2.py +++ b/pytorch_transformers/tokenization_gpt2.py @@ -23,8 +23,6 @@ import os import regex as re from io import open -from .model_utils import clean_up_tokenization - try: from functools import lru_cache except ImportError: @@ -33,24 +31,38 @@ except ImportError: def lru_cache(): return lambda func: func -from .file_utils import cached_path +from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization logger = logging.getLogger(__name__) -PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", - 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", +VOCAB_FILES_NAMES = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt', + 'special_tokens_file': 'special_tokens.txt' } -PRETRAINED_MERGES_ARCHIVE_MAP = { - 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", - 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", + 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", + }, + 'merges_file': + { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", + 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", + }, + 'special_tokens_file': + { + 'gpt2': None, + 'gpt2-medium': None, + } } -PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'gpt2': 1024, + 'gpt2-medium': 1024, } -VOCAB_NAME = 'vocab.json' -MERGES_NAME = 'merges.txt' -SPECIAL_TOKENS_NAME = 'special_tokens.txt' @lru_cache() def bytes_to_unicode(): @@ -87,70 +99,16 @@ def get_pairs(word): prev_char = char return pairs -class GPT2Tokenizer(object): +class GPT2Tokenizer(PreTrainedTokenizer): """ GPT-2 BPE tokenizer. Peculiarities: - Byte-level BPE """ - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a GPT2Tokenizer from a pre-trained model file. - Download and cache the pre-trained model file if needed. 
- """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] - special_tokens_file = None - else: - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) - merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) - special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) - if not os.path.exists(special_tokens_file): - special_tokens_file = None - else: - logger.info("loading special tokens file {}".format(special_tokens_file)) - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) - except EnvironmentError: - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - logger.error( - "Couldn't reach server at '{}' to download vocabulary.".format( - vocab_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " - "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - vocab_file, merges_file)) - return None - if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - logger.info("loading merges file {}".format(merges_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - logger.info("loading merges file {} from cache at {}".format( - merges_file, resolved_merges_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) - # Instantiate tokenizer. 
- if special_tokens_file and 'special_tokens' not in kwargs: - special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - else: - special_tokens = kwargs.pop('special_tokens', []) - tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) - return tokenizer + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): + def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, errors='replace', max_len=None): self.max_len = max_len if max_len is not None else int(1e12) self.encoder = json.load(open(vocab_file)) self.decoder = {v:k for k,v in self.encoder.items()} @@ -165,9 +123,16 @@ class GPT2Tokenizer(object): # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + all_special_tokens = [] + if special_tokens_file is not None: + special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + all_special_tokens.extend(special_tokens_to_add) + if special_tokens is not None and special_tokens: + all_special_tokens.extend(special_tokens) + self.special_tokens = {} self.special_tokens_decoder = {} - self.set_special_tokens(special_tokens) + self.set_special_tokens(all_special_tokens) def __len__(self): return len(self.encoder) + len(self.special_tokens) @@ -285,9 +250,9 @@ class GPT2Tokenizer(object): if not os.path.isdir(vocab_path): logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) return - vocab_file = os.path.join(vocab_path, VOCAB_NAME) - merge_file = os.path.join(vocab_path, MERGES_NAME) - special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) + special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file']) with open(vocab_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) diff --git a/pytorch_transformers/tokenization_openai.py b/pytorch_transformers/tokenization_openai.py index 7d005a82600..419dfdad921 100644 --- a/pytorch_transformers/tokenization_openai.py +++ b/pytorch_transformers/tokenization_openai.py @@ -26,23 +26,35 @@ from io import open from tqdm import tqdm from .file_utils import cached_path -from .model_utils import clean_up_tokenization +from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization from .tokenization_bert import BasicTokenizer logger = logging.getLogger(__name__) -PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", +VOCAB_FILES_NAMES = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt', + 'special_tokens_file': 'special_tokens.txt' } -PRETRAINED_MERGES_ARCHIVE_MAP = { - 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { + 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", + }, + 'merges_file': + { + 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", + }, 
+ 'special_tokens_file': + { + 'openai-gpt': None, + } } -PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'openai-gpt': 512, } -VOCAB_NAME = 'vocab.json' -MERGES_NAME = 'merges.txt' -SPECIAL_TOKENS_NAME = 'special_tokens.txt' def get_pairs(word): """ @@ -71,7 +83,7 @@ def text_standardize(text): text = re.sub(r'[^\S\n]+', ' ', text) return text.strip() -class OpenAIGPTTokenizer(object): +class OpenAIGPTTokenizer(PreTrainedTokenizer): """ BPE tokenizer. Peculiarities: - lower case all inputs @@ -79,65 +91,11 @@ class OpenAIGPTTokenizer(object): - argument special_tokens and function set_special_tokens: can be used to add additional symbols (ex: "__classify__") to a vocabulary. """ - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a PreTrainedBertModel from a pre-trained model file. - Download and cache the pre-trained model file if needed. - """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] - special_tokens_file = None - else: - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) - merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) - special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) - if not os.path.exists(special_tokens_file): - special_tokens_file = None - else: - logger.info("loading special tokens file {}".format(special_tokens_file)) - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) - except EnvironmentError: - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - logger.error( - "Couldn't reach server at '{}' to download vocabulary.".format( - vocab_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " - "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - vocab_file, merges_file)) - return None - if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - logger.info("loading merges file {}".format(merges_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - logger.info("loading merges file {} from cache at {}".format( - merges_file, resolved_merges_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) - # Instantiate tokenizer. 
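The nested maps above replace the flat PRETRAINED_VOCAB_ARCHIVE_MAP / PRETRAINED_MERGES_ARCHIVE_MAP constants: they are keyed first by file type ('vocab_file', 'merges_file', 'special_tokens_file') and only then by model shortcut name, so one generic loader can resolve every file a tokenizer needs. The snippet below is a hypothetical, stand-alone sketch of that lookup (the real logic lives in PreTrainedTokenizer._from_pretrained, added later in this patch); the function name resolve_vocab_files is made up for illustration.

    import os

    VOCAB_FILES_NAMES = {
        'vocab_file': 'vocab.json',
        'merges_file': 'merges.txt',
        'special_tokens_file': 'special_tokens.txt',
    }
    PRETRAINED_VOCAB_FILES_MAP = {
        'vocab_file': {'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json"},
        'merges_file': {'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt"},
        'special_tokens_file': {'openai-gpt': None},
    }

    def resolve_vocab_files(name_or_path):
        """Map each file type to a URL (known shortcut), a local path, or None."""
        vocab_files = {}
        if name_or_path in PRETRAINED_VOCAB_FILES_MAP['vocab_file']:
            # Known shortcut name: take the registered URL for every file type.
            for file_id, url_map in PRETRAINED_VOCAB_FILES_MAP.items():
                vocab_files[file_id] = url_map[name_or_path]
        else:
            # Local directory: look for the canonical file name of every file type.
            for file_id, file_name in VOCAB_FILES_NAMES.items():
                candidate = os.path.join(name_or_path, file_name)
                vocab_files[file_id] = candidate if os.path.exists(candidate) else None
        return vocab_files

    # {'vocab_file': <url>, 'merges_file': <url>, 'special_tokens_file': None}
    print(resolve_vocab_files('openai-gpt'))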
- if special_tokens_file and 'special_tokens' not in kwargs: - special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - else: - special_tokens = kwargs.pop('special_tokens', []) - tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) - return tokenizer + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): + def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None): try: import ftfy import spacy @@ -156,9 +114,17 @@ class OpenAIGPTTokenizer(object): merges = [tuple(merge.split()) for merge in merges] self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} + + all_special_tokens = [] + if special_tokens_file is not None: + special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + all_special_tokens.extend(special_tokens_to_add) + if special_tokens is not None and special_tokens: + all_special_tokens.extend(special_tokens) + self.special_tokens = {} self.special_tokens_decoder = {} - self.set_special_tokens(special_tokens) + self.set_special_tokens(all_special_tokens) def __len__(self): return len(self.encoder) + len(self.special_tokens) @@ -286,9 +252,9 @@ class OpenAIGPTTokenizer(object): if not os.path.isdir(vocab_path): logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) return - vocab_file = os.path.join(vocab_path, VOCAB_NAME) - merge_file = os.path.join(vocab_path, MERGES_NAME) - special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) + special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file']) with open(vocab_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index 7e83680770b..a86c8fe4600 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -31,7 +31,7 @@ import torch import numpy as np from .file_utils import cached_path -from .model_utils import clean_up_tokenization +from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization if sys.version_info[0] == 2: import cPickle as pickle @@ -41,66 +41,35 @@ else: logger = logging.getLogger(__name__) -PRETRAINED_VOCAB_ARCHIVE_MAP = { - 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", +VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'pretrained_vocab_file': + { + 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'transfo-xl-wt103': 512, } -VOCAB_NAME = 'vocab.bin' PRETRAINED_CORPUS_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin", } CORPUS_NAME = 'corpus.bin' -class TransfoXLTokenizer(object): +class TransfoXLTokenizer(PreTrainedTokenizer): """ Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl """ - @classmethod - def 
from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a TransfoXLTokenizer. - The TransfoXLTokenizer. - """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - else: - if os.path.isdir(pretrained_model_name_or_path): - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) - else: - vocab_file = pretrained_model_name_or_path - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - except EnvironmentError: - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - logger.error( - "Couldn't reach server at '{}' to download vocabulary.".format( - vocab_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} " - "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - vocab_file)) - return None - if resolved_vocab_file == vocab_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - - # Instantiate tokenizer. - tokenizer = cls(*inputs, **kwargs) - vocab_dict = torch.load(resolved_vocab_file) - for key, value in vocab_dict.items(): - tokenizer.__dict__[key] = value - return tokenizer + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False, - delimiter=None, vocab_file=None, never_split=("", "", "")): + delimiter=None, vocab_file=None, pretrained_vocab_file=None, + never_split=("", "", "")): self.counter = Counter() self.special = special self.min_freq = min_freq @@ -110,6 +79,13 @@ class TransfoXLTokenizer(object): self.vocab_file = vocab_file self.never_split = never_split + if pretrained_vocab_file is not None: + # Hack because, honestly this tokenizer was not made to be used + # in a library like ours, at all. + vocab_dict = torch.load(pretrained_vocab_file) + for key, value in vocab_dict.items(): + self.__dict__[key] = value + if vocab_file is not None: self.build_vocab() @@ -157,7 +133,7 @@ class TransfoXLTokenizer(object): """Save the tokenizer vocabulary to a directory or file.""" index = 0 if os.path.isdir(vocab_path): - vocab_file = os.path.join(vocab_path, VOCAB_NAME) + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file']) torch.save(self.__dict__, vocab_file) return (vocab_file,) @@ -484,7 +460,7 @@ class TransfoXLCorpus(object): "We assumed '{}' was a path or url but couldn't find files {} " "at this path or url.".format( pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()), pretrained_model_name_or_path, corpus_file)) return None diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py new file mode 100644 index 00000000000..98a29685399 --- /dev/null +++ b/pytorch_transformers/tokenization_utils.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Base class and common utilities for tokenizers."""
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
+
+import sys
+import json
+import logging
+import os
+import regex as re
+from io import open
+
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache():
+        return lambda func: func
+
+from .file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+
+class PreTrainedTokenizer(object):
+    """ An abstract class to handle downloading and loading pretrained tokenizers.
+    """
+    vocab_files_names = {}
+    pretrained_vocab_files_map = {}
+    max_model_input_sizes = {}
+
+    @classmethod
+    def from_pretrained(cls, *inputs, **kwargs):
+        return cls._from_pretrained(*inputs, **kwargs)
+
+    @classmethod
+    def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
+        Download and cache the vocabulary files if needed.
+        """
+        s3_models = list(cls.max_model_input_sizes.keys())
+        vocab_files = {}
+        if pretrained_model_name_or_path in s3_models:
+            for file_id, map_list in cls.pretrained_vocab_files_map.items():
+                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+        else:
+            for file_id, file_name in cls.vocab_files_names.items():
+                if os.path.isdir(pretrained_model_name_or_path):
+                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
+                else:
+                    full_file_name = pretrained_model_name_or_path
+                if not os.path.exists(full_file_name):
+                    logger.info("Didn't find file {}. We don't load it.".format(full_file_name))
+                    full_file_name = None
+                vocab_files[file_id] = full_file_name
+        # redirect to the cache, if necessary
+        try:
+            resolved_vocab_files = {}
+            for file_id, file_path in vocab_files.items():
+                if file_path is None:
+                    resolved_vocab_files[file_id] = None
+                else:
+                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in s3_models:
+                logger.error("Couldn't reach server to download vocabulary.")
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path, ', '.join(s3_models),
+                        pretrained_model_name_or_path, str(vocab_files.keys())))
+            return None
+
+        for file_id, file_path in vocab_files.items():
+            if file_path == resolved_vocab_files[file_id]:
+                logger.info("loading file {}".format(file_path))
+            else:
+                logger.info("loading file {} from cache at {}".format(
+                    file_path, resolved_vocab_files[file_id]))
+
+        if pretrained_model_name_or_path in cls.max_model_input_sizes:
+            # if we're using a pretrained model, ensure the tokenizer
+            # won't index sequences longer than the number of positional embeddings
+            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+
+        # Instantiate tokenizer.
+        tokenizer = cls(*inputs, **resolved_vocab_files, **kwargs)
+
+        return tokenizer
+
+
+def clean_up_tokenization(out_string):
+    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+                           ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+                           ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
+    return out_string
diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py
index 26c73c56b27..e37f3888a30 100644
--- a/pytorch_transformers/tokenization_xlm.py
+++ b/pytorch_transformers/tokenization_xlm.py
@@ -26,30 +26,42 @@ from io import open
 from tqdm import tqdm
 
 from .file_utils import cached_path
-from .model_utils import clean_up_tokenization
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
 from .tokenization_bert import BasicTokenizer
 
 logger = logging.getLogger(__name__)
 
-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
+    'special_tokens_file': 'special_tokens.txt'
 }
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
+    {
+        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
+    },
+    'merges_file':
+    {
+        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
+    },
+    'special_tokens_file':
+    {
+        'xlm-mlm-en-2048': None,
+    }
 }
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'xlm-mlm-en-2048': 512,
 }
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
-SPECIAL_TOKENS_NAME = 'special_tokens.txt'
 
-INDEX= {
-    "bos_index": 0,
-    "eos_index": 1,
-    "pad_index": 2,
-    "unk_index": 3,
-    "mask_index": 5
+INDEX = {
+    "bos_index": 0,
+    "eos_index": 1,
+    "pad_index": 2,
+    "unk_index": 3,
+    "mask_index": 5
 }
 
 def get_pairs(word):
@@ -79,7 +91,7 @@ def text_standardize(text):
     text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()
 
-class XLMTokenizer(object):
+class XLMTokenizer(PreTrainedTokenizer):
     """
     BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
     - lower case all inputs
@@ -87,65 +99,11 @@ class XLMTokenizer(object):
     - argument special_tokens and function set_special_tokens: can be used to add additional symbols (ex: "__classify__") to a vocabulary.
""" - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a PreTrainedBertModel from a pre-trained model file. - Download and cache the pre-trained model file if needed. - """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] - special_tokens_file = None - else: - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) - merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) - special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) - if not os.path.exists(special_tokens_file): - special_tokens_file = None - else: - logger.info("loading special tokens file {}".format(special_tokens_file)) - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) - except EnvironmentError: - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - logger.error( - "Couldn't reach server at '{}' to download vocabulary.".format( - vocab_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {} and {} " - "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - vocab_file, merges_file)) - return None - if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - logger.info("loading merges file {}".format(merges_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - logger.info("loading merges file {} from cache at {}".format( - merges_file, resolved_merges_file)) - if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: - # if we're using a pretrained model, ensure the tokenizer wont index sequences longer - # than the number of positional embeddings - max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) - # Instantiate tokenizer. 
- if special_tokens_file and 'special_tokens' not in kwargs: - special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - else: - special_tokens = kwargs.pop('special_tokens', []) - tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) - return tokenizer + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): + def __init__(self, vocab_file, merges_file, special_tokens_file=None, special_tokens=None, max_len=None): try: import ftfy import spacy @@ -164,9 +122,17 @@ class XLMTokenizer(object): merges = [tuple(merge.split()[:2]) for merge in merges] self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} + + all_special_tokens = [] + if special_tokens_file is not None: + special_tokens_to_add = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + all_special_tokens.extend(special_tokens_to_add) + if special_tokens is not None and special_tokens: + all_special_tokens.extend(special_tokens) + self.special_tokens = {} self.special_tokens_decoder = {} - self.set_special_tokens(special_tokens) + self.set_special_tokens(all_special_tokens) def __len__(self): return len(self.encoder) + len(self.special_tokens) @@ -294,9 +260,9 @@ class XLMTokenizer(object): if not os.path.isdir(vocab_path): logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) return - vocab_file = os.path.join(vocab_path, VOCAB_NAME) - merge_file = os.path.join(vocab_path, MERGES_NAME) - special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['merges_file']) + special_tokens_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['special_tokens_file']) with open(vocab_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.encoder, ensure_ascii=False)) diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index 76b9a9f8707..a30e6db8da0 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -27,15 +27,24 @@ import unicodedata import six from .file_utils import cached_path -from .model_utils import clean_up_tokenization +from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization logger = logging.getLogger(__name__) -PRETRAINED_VOCAB_ARCHIVE_MAP = { +VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", + } } + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'xlnet-large-cased': 512, +} + VOCAB_NAME = 'spiece.model' -SPECIAL_TOKENS_NAME = 'special_tokens.txt' SPIECE_UNDERLINE = u'▁' @@ -46,7 +55,7 @@ SEG_ID_CLS = 2 SEG_ID_SEP = 3 SEG_ID_PAD = 4 -class XLNetTokenizer(object): +class XLNetTokenizer(PreTrainedTokenizer): """ SentencePiece based tokenizer. Peculiarities: - requires SentencePiece: https://github.com/google/sentencepiece @@ -63,64 +72,11 @@ class XLNetTokenizer(object): "" : 7, "" : 8, } - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): - """ - Instantiate a PreTrainedBertModel from a pre-trained model file. 
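GPT-2, OpenAI GPT and XLM now take both a special_tokens_file and a special_tokens argument and simply concatenate the two lists before calling set_special_tokens, which indexes the extra symbols after the end of the BPE vocabulary. A toy illustration of the resulting indices, using a made-up two-word vocabulary rather than a real tokenizer:

    # Mirrors the indexing pattern used by set_special_tokens in these tokenizers.
    encoder = {'hello': 0, 'world': 1}                 # pretend BPE vocabulary
    special_tokens = ['<start>', '<end>', '__classify__']

    special_tokens_map = {tok: len(encoder) + i for i, tok in enumerate(special_tokens)}
    special_tokens_decoder = {v: k for k, v in special_tokens_map.items()}

    print(special_tokens_map)                  # {'<start>': 2, '<end>': 3, '__classify__': 4}
    print(len(encoder) + len(special_tokens))  # 5, the value reported by __len__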
- Download and cache the pre-trained model file if needed. - """ - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] - special_tokens_file = None - if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is a cased model but you have not set " - "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " - "you may want to check this behavior.") - kwargs['do_lower_case'] = False - elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is an uncased model but you have set " - "`do_lower_case` to False. We are setting `do_lower_case=True` for you " - "but you may want to check this behavior.") - kwargs['do_lower_case'] = True - else: - vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) - special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) - if not os.path.exists(special_tokens_file): - special_tokens_file = None - else: - logger.info("loading special tokens file {}".format(special_tokens_file)) - # redirect to the cache, if necessary - try: - resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) - except EnvironmentError: - if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: - logger.error( - "Couldn't reach server at '{}' to download vocabulary.".format( - vocab_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find files {}" - "at this path or url.".format( - pretrained_model_name_or_path, - ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), - pretrained_model_name_or_path, - vocab_file)) - return None - if resolved_vocab_file == vocab_file: - logger.info("loading vocabulary file {}".format(vocab_file)) - else: - logger.info("loading vocabulary file {} from cache at {}".format( - vocab_file, resolved_vocab_file)) - # Instantiate tokenizer. - if special_tokens_file and 'special_tokens' not in kwargs: - special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] - else: - special_tokens = kwargs.pop('special_tokens', []) - tokenizer = cls(resolved_vocab_file, special_tokens=special_tokens, *inputs, **kwargs) - return tokenizer + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, special_tokens=None, max_len=None, + def __init__(self, vocab_file, max_len=None, do_lower_case=False, remove_space=True, keep_accents=False): try: import sentencepiece as spm @@ -136,9 +92,6 @@ class XLNetTokenizer(object): self.sp_model = spm.SentencePieceProcessor() self.sp_model.Load(vocab_file) - self.special_tokens = {} - self.special_tokens_decoder = {} - self.set_special_tokens(special_tokens) @property def UNK_TOKEN(self): @@ -181,7 +134,7 @@ class XLNetTokenizer(object): return self.special_symbols[""] def __len__(self): - return len(self.encoder) + len(self.special_tokens) + return len(self.sp_model) def __getstate__(self): state = self.__dict__.copy() @@ -198,19 +151,6 @@ class XLNetTokenizer(object): self.sp_model = spm.SentencePieceProcessor() self.sp_model.Load(self.vocab_file) - def set_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. 
- The additional tokens are indexed starting from the last index of the - current vocabulary in the order of the `special_tokens` list. - """ - if not special_tokens: - self.special_tokens = {} - self.special_tokens_decoder = {} - return - self.special_tokens = dict((tok, len(self.sp_model) + i) for i, tok in enumerate(special_tokens)) - self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} - logger.info("Special tokens: %s", str(self.special_tokens)) - def preprocess_text(self, inputs): if self.remove_space: outputs = ' '.join(inputs.strip().split()) @@ -272,15 +212,9 @@ class XLNetTokenizer(object): """ Converts a sequence of tokens into ids using the vocab. """ ids = [] if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): - if tokens in self.special_tokens: - return self.special_tokens[tokens] - else: - return self.sp_model.PieceToId(tokens) + return self.sp_model.PieceToId(tokens) for token in tokens: - if token in self.special_tokens: - ids.append(self.special_tokens[token]) - else: - ids.append(self.sp_model.PieceToId(token)) + ids.append(self.sp_model.PieceToId(token)) if len(ids) > self.max_len: logger.warning( "Token indices sequence length is longer than the specified maximum " @@ -289,15 +223,11 @@ class XLNetTokenizer(object): ) return ids - def convert_ids_to_tokens(self, ids, return_unicode=True, skip_special_tokens=False): + def convert_ids_to_tokens(self, ids, return_unicode=True): """Converts a sequence of ids in tokens.""" tokens = [] for i in ids: - if i in self.special_tokens_decoder: - if not skip_special_tokens: - tokens.append(self.special_tokens_decoder[i]) - else: - tokens.append(self.sp_model.IdToPiece(i)) + tokens.append(self.sp_model.IdToPiece(i)) if six.PY2 and return_unicode: ret_pieces = [] @@ -311,9 +241,9 @@ class XLNetTokenizer(object): def encode(self, text, sample=False): return self.convert_tokens_to_ids(self.tokenize(text, sample=sample)) - def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): + def decode(self, ids, clean_up_tokenization_spaces=True): """Converts a sequence of ids in a string.""" - tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) + tokens = self.convert_ids_to_tokens(ids) out_string = ''.join(tokens) if clean_up_tokenization_spaces: out_string = out_string.strip().replace('', '') @@ -328,18 +258,7 @@ class XLNetTokenizer(object): logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) return out_vocab_file = os.path.join(vocab_path, VOCAB_NAME) - special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) copyfile(self.vocab_file, out_vocab_file) - index = len(self.sp_model) - with open(special_tokens_file, 'w', encoding='utf-8') as writer: - for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): - if index != token_index: - logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." - " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) - index = token_index - writer.write(token + u'\n') - index += 1 - - return out_vocab_file, special_tokens_file + return (out_vocab_file,)
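The clean_up_tokenization helper shared by the decode methods only reattaches punctuation and a few English contractions to the preceding word. With the corrected version above (the one that assigns the result of the replace chain), a made-up example behaves as follows:

    from pytorch_transformers.tokenization_utils import clean_up_tokenization

    print(clean_up_tokenization("do n't hesitate , it 's fine ."))
    # don't hesitate, it's fine.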
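With the simplified XLNet tokenizer, save_vocabulary now only copies the SentencePiece model and returns a one-element tuple, and reloading goes through the same PreTrainedTokenizer machinery. A rough usage sketch; the directory path is made up, and running it needs the sentencepiece package plus network access for the initial download:

    import os
    from pytorch_transformers.tokenization_xlnet import XLNetTokenizer

    save_dir = "/tmp/xlnet_tok"          # made-up location
    os.makedirs(save_dir, exist_ok=True)

    tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
    (vocab_path,) = tokenizer.save_vocabulary(save_dir)   # copies spiece.model, nothing else
    print(vocab_path)                                      # /tmp/xlnet_tok/spiece.model

    # from_pretrained on a directory resolves 'spiece.model' via vocab_files_names.
    reloaded = XLNetTokenizer.from_pretrained(save_dir)
    ids = reloaded.encode("Hello world")
    print(reloaded.decode(ids))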