diff --git a/examples/eval_transfo_xl.py b/examples/eval_transfo_xl.py
new file mode 100644
index 00000000000..886e826b2c8
--- /dev/null
+++ b/examples/eval_transfo_xl.py
@@ -0,0 +1,151 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Transformer XL model evaluation script.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
+    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
+"""
+import os
+import sys
+import functools
+import argparse
+import time
+import math
+
+import torch
+
+from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
+
+def logging(s, log_path, print_=True, log_=True):
+    if print_:
+        print(s)
+    if log_:
+        with open(log_path, 'a+') as f_log:
+            f_log.write(s + '\n')
+
+def get_logger(log_path, **kwargs):
+    return functools.partial(logging, log_path=log_path, **kwargs)
+
+parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
+# parser.add_argument('--data', type=str, default='../data/wikitext-103',
+#                     help='location of the data corpus')
+parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
+                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
+                    help='pretrained model name')
+parser.add_argument('--split', type=str, default='all',
+                    choices=['all', 'valid', 'test'],
+                    help='which split to evaluate')
+parser.add_argument('--batch_size', type=int, default=10,
+                    help='batch size')
+parser.add_argument('--tgt_len', type=int, default=5,
+                    help='number of tokens to predict')
+parser.add_argument('--ext_len', type=int, default=0,
+                    help='length of the extended context')
+parser.add_argument('--mem_len', type=int, default=0,
+                    help='length of the retained previous heads')
+parser.add_argument('--clamp_len', type=int, default=-1,
+                    help='max positional embedding index')
+parser.add_argument('--cuda', action='store_true',
+                    help='use CUDA')
+parser.add_argument('--work_dir', type=str, required=True,
+                    help='path to the work_dir')
+parser.add_argument('--no_log', action='store_true',
+                    help='do not log the eval result')
+parser.add_argument('--same_length', action='store_true',
+                    help='set same length attention with masking')
+args = parser.parse_args()
+assert args.ext_len >= 0, 'extended context length must be non-negative'
+
+device = torch.device("cuda" if args.cuda else "cpu")
+
+# Get logger
+logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
+                     log_=not args.no_log)
+
+# Load dataset
+corpus = TransfoXLCorpus.from_pretrained(args.model_name)
+ntokens = len(corpus.vocab)
+
+va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
+                              device=device, ext_len=args.ext_len)
+te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
+                              device=device, ext_len=args.ext_len)
+
+# Load the best saved model.
+# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
+#     model = torch.load(f)
+# model.backward_compatible()
+model = TransfoXLModel.from_pretrained(args.model_name)
+model = model.to(device)
+
+logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
+    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
+
+model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
+if args.clamp_len > 0:
+    model.clamp_len = args.clamp_len
+if args.same_length:
+    model.same_length = True
+
+###############################################################################
+# Evaluation code
+###############################################################################
+def evaluate(eval_iter):
+    # Turn on evaluation mode which disables dropout.
+    model.eval()
+    total_len, total_loss = 0, 0.
+    start_time = time.time()
+    with torch.no_grad():
+        mems = tuple()
+        for idx, (data, target, seq_len) in enumerate(eval_iter):
+            ret = model(data, target, *mems)
+            loss, mems = ret[0], ret[1:]
+            loss = loss.mean()
+            total_loss += seq_len * loss.item()
+            total_len += seq_len
+        total_time = time.time() - start_time
+    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
+        total_time, 1000 * total_time / (idx+1)))
+    return total_loss / total_len
+
+# Run on test data.
+if args.split == 'all':
+    test_loss = evaluate(te_iter)
+    valid_loss = evaluate(va_iter)
+elif args.split == 'valid':
+    valid_loss = evaluate(va_iter)
+    test_loss = None
+elif args.split == 'test':
+    test_loss = evaluate(te_iter)
+    valid_loss = None
+
+def format_log(loss, split):
+    if args.model_name in ['enwik8', 'text8']:
+        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
+            split, loss, loss / math.log(2))
+    else:
+        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
+            split, loss, math.exp(loss))
+    return log_str
+
+log_str = ''
+if valid_loss is not None:
+    log_str += format_log(valid_loss, 'valid')
+if test_loss is not None:
+    log_str += format_log(test_loss, 'test')
+
+logging('=' * 100)
+logging(log_str)
+logging('=' * 100)
diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py
index 0a9e41266dd..85f2422af64 100644
--- a/pytorch_pretrained_bert/__init__.py
+++ b/pytorch_pretrained_bert/__init__.py
@@ -1,12 +1,14 @@
 __version__ = "0.5.0"
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
+from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
                        BertForSequenceClassification, BertForMultipleChoice,
                        BertForTokenClassification, BertForQuestionAnswering)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
+from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
index 861c26280dd..6962481adc6 100755
--- a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
+++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py
@@ -12,23 +12,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Convert OpenAI GPT checkpoint."""
+"""Convert Transformer XL checkpoint and datasets."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
-import re
+import sys
 import argparse
+import pickle
+
 import tensorflow as tf
 import torch
 import numpy as np
 
 from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
+
+# We do this to be able to load the python 2 datasets pickles
+# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
+import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
+data_utils.Vocab = data_utils.TransfoXLTokenizer
+data_utils.Corpus = data_utils.TransfoXLCorpus
+sys.modules['data_utils'] = data_utils
+sys.modules['vocabulary'] = data_utils
 
 def build_tf_to_pytorch_map(model, config):
-    """ A map of modules from TF to PyTorch """
+    """ A map of modules from TF to PyTorch.
+        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
+    """
     tf_to_pt_map = {}
     # Embeddings cutoffs
     for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
@@ -95,88 +108,108 @@ def build_tf_to_pytorch_map(model, config):
 
 def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, transfo_xl_config_file,
-                                             pytorch_dump_folder_path):
-    config_path = os.path.abspath(transfo_xl_config_file)
-    tf_path = os.path.abspath(tf_checkpoint_path)
+                                             pytorch_dump_folder_path,
+                                             transfo_xl_dataset_file):
+    if transfo_xl_dataset_file:
+        with open(transfo_xl_dataset_file, "rb") as fp:
+            corpus = pickle.load(fp, encoding="latin1")
+        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
+        torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
 
-    print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
-    # Initialise PyTorch model
-    # Construct model
-    if transfo_xl_config_file == "":
-        config = TransfoXLConfig()
-    else:
-        config = TransfoXLConfig(transfo_xl_config_file)
-    print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = TransfoXLModel(config)
+        corpus_dict_no_vocab = corpus.__dict__
+        corpus_dict_no_vocab.pop('vocab', None)
+        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
+        print("Save dataset to {}".format(pytorch_dataset_dump_path))
+        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
 
-    # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
+    if tf_checkpoint_path:
+        config_path = os.path.abspath(transfo_xl_config_file)
+        tf_path = os.path.abspath(tf_checkpoint_path)
 
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    tf_weights = {}
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        tf_weights[name] = array
+        print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
+        # Initialise PyTorch model
+        # Construct model
+        if transfo_xl_config_file == "":
+            config = TransfoXLConfig()
+        else:
+            config = TransfoXLConfig(transfo_xl_config_file)
+        print("Building PyTorch model from configuration: {}".format(str(config)))
+        model = TransfoXLModel(config)
 
-    for name, pointer in tf_to_pt_map.items():
-        assert name in tf_weights
-        array = tf_weights[name]
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
-        # which are not required for using pretrained model
-        if 'kernel' in name or 'proj_W' in name:
-            array = np.transpose(array)
-        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
-            # Here we will split the TF weigths
-            assert len(pointer) == array.shape[0]
-            for i, p_i in enumerate(pointer):
-                arr_i = array[i, ...]
-                try:
-                    assert p_i.shape == arr_i.shape
-                except AssertionError as e:
-                    e.args += (p_i.shape, arr_i.shape)
-                    raise
-                print("Initialize PyTorch weight {} for layer {}".format(name, i))
-                p_i.data = torch.from_numpy(arr_i)
-            continue
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        print("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
+        # Build TF to PyTorch weights loading map
+        tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
 
-    # Save pytorch-model
-    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
-    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
-    torch.save(model.state_dict(), pytorch_weights_dump_path)
-    print("Save configuration file to {}".format(pytorch_config_dump_path))
-    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
-        f.write(config.to_json_string())
+        # Load weights from TF model
+        init_vars = tf.train.list_variables(tf_path)
+        tf_weights = {}
+        for name, shape in init_vars:
+            print("Loading TF weight {} with shape {}".format(name, shape))
+            array = tf.train.load_variable(tf_path, name)
+            tf_weights[name] = array
+
+        for name, pointer in tf_to_pt_map.items():
+            assert name in tf_weights
+            array = tf_weights[name]
+            # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+            # which are not required for using pretrained model
+            if 'kernel' in name or 'proj_W' in name:
+                array = np.transpose(array)
+            if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
+                # Here we will split the TF weights
+                assert len(pointer) == array.shape[0]
+                for i, p_i in enumerate(pointer):
+                    arr_i = array[i, ...]
+                    try:
+                        assert p_i.shape == arr_i.shape
+                    except AssertionError as e:
+                        e.args += (p_i.shape, arr_i.shape)
+                        raise
+                    print("Initialize PyTorch weight {} for layer {}".format(name, i))
+                    p_i.data = torch.from_numpy(arr_i)
+                continue
+            try:
+                assert pointer.shape == array.shape
+            except AssertionError as e:
+                e.args += (pointer.shape, array.shape)
+                raise
+            print("Initialize PyTorch weight {}".format(name))
+            pointer.data = torch.from_numpy(array)
+
+        # Save pytorch-model
+        pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
+        pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
+        print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
+        torch.save(model.state_dict(), pytorch_weights_dump_path)
+        print("Save configuration file to {}".format(pytorch_config_dump_path))
+        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+            f.write(config.to_json_string())
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--tf_checkpoint_path",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "Path the TensorFlow checkpoint path.")
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
-                        help = "Path to the output PyTorch model.")
+                        help = "Path to the folder to store the PyTorch model or dataset/vocab.")
+    parser.add_argument("--tf_checkpoint_path",
+                        default = "",
+                        type = str,
+                        help = "An optional path to a TensorFlow checkpoint path to be converted.")
     parser.add_argument("--transfo_xl_config_file",
                         default = "",
                         type = str,
-                        help = "The config json file corresponding to the pre-trained BERT model. \n"
+                        help = "An optional config json file corresponding to the pre-trained Transformer XL model. \n"
                             "This specifies the model architecture.")
+    parser.add_argument("--transfo_xl_dataset_file",
+                        default = "",
+                        type = str,
+                        help = "An optional dataset file to be converted into a vocabulary.")
     args = parser.parse_args()
     convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                              args.transfo_xl_config_file,
-                                             args.pytorch_dump_folder_path)
+                                             args.pytorch_dump_folder_path,
+                                             args.transfo_xl_dataset_file)
diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py
index a7a9ca2e5be..5b80f045a4b 100644
--- a/pytorch_pretrained_bert/modeling_transfo_xl.py
+++ b/pytorch_pretrained_bert/modeling_transfo_xl.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
     In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
 """
 
@@ -40,7 +40,7 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)
 
 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
 }
 CONFIG_NAME = 'transfo_xl_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
@@ -59,12 +59,13 @@ class TransfoXLConfig(object):
                  div_val=4,
                  pre_lnorm=False,
                  n_layer=18,
-                 tgt_len=256,
+                 tgt_len=128,
                  ext_len=0,
-                 mem_len=256,
-                 same_length=False,
+                 mem_len=1600,
+                 clamp_len=1000,
+                 same_length=True,
+                 proj_share_all_but_first=True,
                  attn_type=0,
-                 clamp_len=-1,
                  sample_softmax=-1,
                  adaptive=True,
                  tie_weight=True,
@@ -93,6 +94,7 @@ class TransfoXLConfig(object):
                 ext_len: length of the extended context
                 mem_len: length of the retained previous heads
                 same_length: use the same attn length for all tokens
+                proj_share_all_but_first: True to share all but first projs, False not to share.
                 attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
                 clamp_len: use the same pos embeddings after clamp_len
                 sample_softmax: number of samples in sampled softmax
@@ -118,7 +120,10 @@ class TransfoXLConfig(object):
             self.cutoffs = []
             self.cutoffs.extend(cutoffs)
             self.tie_weight = tie_weight
-            self.tie_projs = [False] + [True] * len(self.cutoffs)
+            if proj_share_all_but_first:
+                self.tie_projs = [False] + [True] * len(self.cutoffs)
+            else:
+                self.tie_projs = [False] + [False] * len(self.cutoffs)
             self.d_model = d_model
             self.d_embed = d_embed
             self.d_head = d_head
@@ -423,7 +428,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
             r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # qlen x n_head x d_head
 
         #### compute attention score
-        rw_head_q = w_head_q + self.r_w_bias                                        # qlen x bsz x n_head x d_head
+        rw_head_q = w_head_q + self.r_w_bias                                         # qlen x bsz x n_head x d_head
         AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
 
         rr_head_q = w_head_q + self.r_r_bias
@@ -915,21 +920,25 @@ class MemTransformerLM(nn.Module):
 
         return core_out, new_mems
 
-    def forward(self, data, target, *mems):
+    def forward(self, data, target=None, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
         if not mems: mems = self.init_mems()
 
-        tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
+        if target is None:
+            if new_mems is None:
+                return [hidden]
+            else:
+                return [hidden] + new_mems
 
+        tgt_len = target.size(0)
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
             assert self.tie_weight
-            logit = sample_logits(self.word_emb,
-                self.out_layer.bias, target, pred_hid, self.sampler)
+            logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
             loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
@@ -1010,7 +1019,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
+    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
                         *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
@@ -1100,7 +1109,7 @@ class TransfoXLPreTrainedModel(nn.Module):
             for name, child in module._modules.items():
                 if child is not None:
                     load(child, prefix + name + '.')
-        load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
+        # load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
         if len(missing_keys) > 0:
             logger.info("Weights of {} not initialized from pretrained model: {}".format(
                 model.__class__.__name__, missing_keys))
@@ -1110,9 +1119,6 @@ class TransfoXLPreTrainedModel(nn.Module):
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))
-        # Add additional embeddings for special tokens if needed
-        if num_special_tokens != config.n_special:
-            model.set_num_special_tokens(num_special_tokens)
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py
index 1d278abcb27..a411c267b98 100644
--- a/pytorch_pretrained_bert/tokenization_transfo_xl.py
+++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -14,15 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
 """
 
 import os
-import re
-import json
-from tqdm import tqdm
+import glob
 import logging
 import pickle
+import torch
 from collections import Counter, OrderedDict
 
 from .file_utils import cached_path
@@ -30,16 +29,14 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)
 
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
 }
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
+VOCAB_NAME = 'vocab.bin'
+
+PRETRAINED_CORPUS_ARCHIVE_MAP = {
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
 }
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
-    'openai-gpt': 512,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
+CORPUS_NAME = 'corpus.bin'
 
 class TransfoXLTokenizer(object):
     """
@@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a TransfoXLTokenizer.
-        Download and cache the vocabulary if needed.
+        Download and cache the vocabulary file if needed.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except FileNotFoundError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                "We assumed '{}' was a path or url but couldn't find files {} "
                 "at this path or url.".format(
                     pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                     pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+                    vocab_file))
             return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+        if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+        # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        tokenizer = cls(*inputs, **kwargs)
+        vocab_dict = torch.load(resolved_vocab_file)
+        for key, value in vocab_dict.items():
+            tokenizer.__dict__[key] = value
         return tokenizer
 
     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
@@ -418,10 +408,53 @@ class LMMultiFileIterator(LMShuffledIterator):
 
             yield batch
 
 
-class Corpus(object):
-    def __init__(self, path, dataset, *args, **kwargs):
+class TransfoXLCorpus(object):
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a pre-processed corpus.
+        """
+        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
+            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
+        else:
+            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
+        except FileNotFoundError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find files {} "
+                "at this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
+                    pretrained_model_name_or_path,
+                    corpus_file))
+            return None
+        if resolved_corpus_file == corpus_file:
+            logger.info("loading corpus file {}".format(corpus_file))
+        else:
+            logger.info("loading corpus file {} from cache at {}".format(
+                corpus_file, resolved_corpus_file))
+
+        # Instantiate corpus.
+        corpus = cls(*inputs, **kwargs)
+        corpus_dict = torch.load(resolved_corpus_file)
+        for key, value in corpus_dict.items():
+            corpus.__dict__[key] = value
+        corpus.vocab = vocab
+        return corpus
+
+    def __init__(self, *args, **kwargs):
+        self.vocab = TransfoXLTokenizer(*args, **kwargs)
+        self.dataset = None
+        self.train = None
+        self.valid = None
+        self.test = None
+
+    def build_corpus(self, path, dataset):
         self.dataset = dataset
-        self.vocab = Vocab(*args, **kwargs)
 
         if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
             self.vocab.count_file(os.path.join(path, 'train.txt'))
@@ -443,20 +476,20 @@ class TransfoXLCorpus(object):
                 os.path.join(path, 'train.txt'), ordered=True)
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=True)
-            self.test  = self.vocab.encode_file(
+            self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=True)
         elif self.dataset in ['enwik8', 'text8']:
             self.train = self.vocab.encode_file(
                 os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
-            self.test  = self.vocab.encode_file(
+            self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
         elif self.dataset == 'lm1b':
             self.train = train_paths
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
-            self.test  = self.vocab.encode_file(
+            self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
 
     def get_iterator(self, split, *args, **kwargs):
@@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
         elif dataset in ['enwik8', 'text8']:
             pass
 
-        corpus = Corpus(datadir, dataset, **kwargs)
+        corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
         torch.save(corpus, fn)
 
     return corpus
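For reference, a minimal usage sketch of the API introduced in this diff (`TransfoXLCorpus`/`TransfoXLModel.from_pretrained`, `reset_length`, `get_iterator`, and the `(loss, *mems)` return convention of the forward pass), mirroring `examples/eval_transfo_xl.py` above. The batch size and lengths are illustrative values only, not recommendations:

```python
import math
import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

# Download the pre-processed WikiText-103 corpus/vocabulary and the pre-trained
# weights via the 'transfo-xl-wt103' shortcut defined in the archive maps above.
corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Illustrative evaluation lengths: tgt_len=128, ext_len=0, mem_len=1600.
model.reset_length(128, 0, 1600)
va_iter = corpus.get_iterator('valid', 10, 128, device=device, ext_len=0)

total_loss, total_len = 0., 0
with torch.no_grad():
    mems = tuple()  # empty mems are (re-)initialized inside the model's forward
    for data, target, seq_len in va_iter:
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]  # loss followed by the updated memories
        total_loss += seq_len * loss.mean().item()
        total_len += seq_len

print('valid ppl {:9.3f}'.format(math.exp(total_loss / total_len)))
```

Calling `model(data)` without a target returns the hidden states followed by the new memories instead of a loss, as implemented in `MemTransformerLM.forward` above.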