improved corpus and tokenization conversion - added evaluation script

This commit is contained in:
thomwolf 2019-01-15 23:17:46 +01:00
parent 7d03c53718
commit a69ec2c722
5 changed files with 344 additions and 119 deletions

examples/eval_transfo_xl.py (new file)

@@ -0,0 +1,151 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
"""
import os
import sys
import functools
import argparse
import time
import math
import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
def logging(s, log_path, print_=True, log_=True):
if print_:
print(s)
if log_:
with open(log_path, 'a+') as f_log:
f_log.write(s + '\n')
def get_logger(log_path, **kwargs):
return functools.partial(logging, log_path=log_path, **kwargs)
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
# help='location of the data corpus')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
help='pretrained model name')
parser.add_argument('--split', type=str, default='all',
choices=['all', 'valid', 'test'],
help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
help='length of the retained previous hidden states (memory)')
parser.add_argument('--clamp_len', type=int, default=-1,
help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'
device = torch.device("cuda" if args.cuda else "cpu")
# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
log_=not args.no_log)
# Load dataset
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)
va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
# Load the best saved model.
# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
# model = torch.load(f)
# model.backward_compatible()
model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device)
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
model.clamp_len = args.clamp_len
if args.same_length:
model.same_length = True
###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout.
model.eval()
total_len, total_loss = 0, 0.
start_time = time.time()
with torch.no_grad():
mems = tuple()
for idx, (data, target, seq_len) in enumerate(eval_iter):
ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss = loss.mean()
total_loss += seq_len * loss.item()
total_len += seq_len
total_time = time.time() - start_time
logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
total_time, 1000 * total_time / (idx+1)))
return total_loss / total_len
# Run on test data.
if args.split == 'all':
test_loss = evaluate(te_iter)
valid_loss = evaluate(va_iter)
elif args.split == 'valid':
valid_loss = evaluate(va_iter)
test_loss = None
elif args.split == 'test':
test_loss = evaluate(te_iter)
valid_loss = None
def format_log(loss, split):
if args.model_name in ['enwik8', 'text8']:
log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
split, loss, loss / math.log(2))
else:
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
split, loss, math.exp(loss))
return log_str
log_str = ''
if valid_loss is not None:
log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
log_str += format_log(test_loss, 'test')
logging('=' * 100)
logging(log_str)
logging('=' * 100)

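The format_log helper above converts the average per-token loss (a natural-log cross-entropy) into the two usual language-modelling metrics: perplexity exp(loss) for word-level corpora and bits-per-character loss / ln(2) for character-level ones. A minimal standalone sketch of that conversion (the loss value below is made up for illustration):

import math

loss = 3.18                # hypothetical average per-token cross-entropy, in nats
ppl = math.exp(loss)       # word-level metric: perplexity
bpc = loss / math.log(2)   # character-level metric: bits per character
print('| loss {:5.2f} | ppl {:9.3f} | bpc {:9.5f}'.format(loss, ppl, bpc))

The script itself is launched as, for example, python examples/eval_transfo_xl.py --work_dir ./eval_out --cuda; only --work_dir is required.
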
pytorch_pretrained_bert/__init__.py

@@ -1,12 +1,14 @@
__version__ = "0.5.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering)
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE

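With the __init__.py change above, the new Transformer XL classes are importable from the package root, which is what the evaluation script relies on:

from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel,
                                     TransfoXLTokenizer, TransfoXLCorpus)
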
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py

@@ -12,23 +12,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
"""Convert Transformer XL checkpoint and datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import sys
import argparse
import pickle
import tensorflow as tf
import torch
import numpy as np
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
# We do this to be able to load the Python 2 dataset pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus
sys.modules['data_utils'] = data_utils
sys.modules['vocabulary'] = data_utils
def build_tf_to_pytorch_map(model, config):
""" A map of modules from TF to PyTorch """
""" A map of modules from TF to PyTorch.
We use a map here to keep the PyTorch model as close to the original PyTorch model as possible.
"""
tf_to_pt_map = {}
# Embeddings cutoffs
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
@@ -95,88 +108,108 @@ def build_tf_to_pytorch_map(model, config):
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
transfo_xl_config_file,
pytorch_dump_folder_path):
config_path = os.path.abspath(transfo_xl_config_file)
tf_path = os.path.abspath(tf_checkpoint_path)
pytorch_dump_folder_path,
transfo_xl_dataset_file):
if transfo_xl_dataset_file:
with open(transfo_xl_dataset_file, "rb") as fp:
corpus = pickle.load(fp, encoding="latin1")
# Save the vocabulary and dataset cache as dictionaries (more robust than pickles in the long term)
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
# Initialise PyTorch model
# Construct model
if transfo_xl_config_file == "":
config = TransfoXLConfig()
else:
config = TransfoXLConfig(transfo_xl_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = TransfoXLModel(config)
corpus_dict_no_vocab = corpus.__dict__
corpus_dict_no_vocab.pop('vocab', None)
pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
print("Save dataset to {}".format(pytorch_dataset_dump_path))
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
if tf_checkpoint_path:
config_path = os.path.abspath(transfo_xl_config_file)
tf_path = os.path.abspath(tf_checkpoint_path)
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
tf_weights[name] = array
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
# Initialise PyTorch model
# Construct model
if transfo_xl_config_file == "":
config = TransfoXLConfig()
else:
config = TransfoXLConfig(transfo_xl_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = TransfoXLModel(config)
for name, pointer in tf_to_pt_map.items():
assert name in tf_weights
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using the pretrained model
if 'kernel' in name or 'proj_W' in name:
array = np.transpose(array)
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
# Here we will split the TF weights
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]
try:
assert p_i.shape == arr_i.shape
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
continue
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
# Build TF to PyTorch weights loading map
tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
tf_weights = {}
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
tf_weights[name] = array
for name, pointer in tf_to_pt_map.items():
assert name in tf_weights
array = tf_weights[name]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using the pretrained model
if 'kernel' in name or 'proj_W' in name:
array = np.transpose(array)
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
# Here we will split the TF weights
assert len(pointer) == array.shape[0]
for i, p_i in enumerate(pointer):
arr_i = array[i, ...]
try:
assert p_i.shape == arr_i.shape
except AssertionError as e:
e.args += (p_i.shape, arr_i.shape)
raise
print("Initialize PyTorch weight {} for layer {}".format(name, i))
p_i.data = torch.from_numpy(arr_i)
continue
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
torch.save(model.state_dict(), pytorch_weights_dump_path)
print("Save configuration file to {}".format(pytorch_config_dump_path))
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
f.write(config.to_json_string())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
parser.add_argument("--tf_checkpoint_path",
default = "",
type = str,
help = "An optional path to a TensorFlow checkpoint path to be converted.")
parser.add_argument("--transfo_xl_config_file",
default = "",
type = str,
help = "The config json file corresponding to the pre-trained BERT model. \n"
help = "An optional config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--transfo_xl_dataset_file",
default = "",
type = str,
help = "An optional dataset file to be converted in a vocabulary.")
args = parser.parse_args()
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.transfo_xl_config_file,
args.pytorch_dump_folder_path)
args.pytorch_dump_folder_path,
args.transfo_xl_dataset_file)

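The sys.modules aliasing near the top of this script is what lets pickle.load resolve classes that were pickled under the original transformer-xl module names (data_utils, vocabulary) against the renamed classes in this package. A self-contained sketch of the same pattern; the module and class names below are illustrative, not from the repo:

import sys
import types

class NewVocab(object):
    """Renamed class standing in for a legacy 'old_utils.Vocab'."""
    pass

# Build a stand-in module whose attributes match what the pickle expects,
# then register it under the legacy module name before unpickling.
legacy = types.ModuleType('old_utils')
legacy.Vocab = NewVocab
sys.modules['old_utils'] = legacy

# A subsequent pickle.load(f, encoding="latin1") now resolves
# 'old_utils.Vocab' to NewVocab instead of raising an ImportError.
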
pytorch_pretrained_bert/modeling_transfo_xl.py

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""
@@ -40,7 +40,7 @@ from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
}
CONFIG_NAME = 'transfo_xl_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
@@ -59,12 +59,13 @@ class TransfoXLConfig(object):
div_val=4,
pre_lnorm=False,
n_layer=18,
tgt_len=256,
tgt_len=128,
ext_len=0,
mem_len=256,
same_length=False,
mem_len=1600,
clamp_len=1000,
same_length=True,
proj_share_all_but_first=True,
attn_type=0,
clamp_len=-1,
sample_softmax=-1,
adaptive=True,
tie_weight=True,
@@ -93,6 +94,7 @@ class TransfoXLConfig(object):
ext_len: length of the extended context
mem_len: length of the retained previous hidden states (memory)
same_length: use the same attn length for all tokens
proj_share_all_but_first: True to share all but the first projection layers, False to share none.
attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
clamp_len: use the same pos embeddings after clamp_len
sample_softmax: number of samples in sampled softmax
@@ -118,7 +120,10 @@ class TransfoXLConfig(object):
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
self.tie_projs = [False] + [True] * len(self.cutoffs)
if proj_share_all_but_first:
self.tie_projs = [False] + [True] * len(self.cutoffs)
else:
self.tie_projs = [False] + [False] * len(self.cutoffs)
self.d_model = d_model
self.d_embed = d_embed
self.d_head = d_head
@@ -423,7 +428,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
#### compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
rr_head_q = w_head_q + self.r_r_bias
@@ -915,21 +920,25 @@ class MemTransformerLM(nn.Module):
return core_out, new_mems
def forward(self, data, target, *mems):
def forward(self, data, target=None, *mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if not mems: mems = self.init_mems()
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
if target is None:
if new_mems is None:
return [hidden]
else:
return [hidden] + new_mems
tgt_len = target.size(0)
pred_hid = hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training:
assert self.tie_weight
logit = sample_logits(self.word_emb,
self.out_layer.bias, target, pred_hid, self.sampler)
logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
loss = -F.log_softmax(logit, -1)[:, :, 0]
else:
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
@@ -1010,7 +1019,7 @@ class TransfoXLPreTrainedModel(nn.Module):
pass
@classmethod
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
*inputs, **kwargs):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
@@ -1100,7 +1109,7 @@ class TransfoXLPreTrainedModel(nn.Module):
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
@@ -1110,9 +1119,6 @@ class TransfoXLPreTrainedModel(nn.Module):
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# Add additional embeddings for special tokens if needed
if num_special_tokens != config.n_special:
model.set_num_special_tokens(num_special_tokens)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)

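The target=None change above means the pretrained model (whose call signature mirrors MemTransformerLM.forward, as the evaluation script's model(data, target, *mems) call shows) can now be called without labels to retrieve hidden states plus memories instead of a loss. A minimal usage sketch, assuming the transfo-xl-wt103 weights are cached locally or downloadable:

import torch
from pytorch_pretrained_bert import TransfoXLModel

model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

data = torch.randint(0, 1000, (5, 2), dtype=torch.long)  # (tgt_len, bsz) token ids
with torch.no_grad():
    ret = model(data)               # no target: [hidden] + new_mems
    hidden, mems = ret[0], ret[1:]
    ret = model(data, data, *mems)  # with target: [loss] + new_mems
    loss, mems = ret[0], ret[1:]
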
pytorch_pretrained_bert/tokenization_transfo_xl.py

@@ -14,15 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
Adapted from https://github.com/kimiyoung/transformer-xl.
"""
import os
import re
import json
from tqdm import tqdm
import glob
import logging
import pickle
import torch
from collections import Counter, OrderedDict
from .file_utils import cached_path
@@ -30,16 +29,14 @@ from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
VOCAB_NAME = 'vocab.bin'
PRETRAINED_CORPUS_ARCHIVE_MAP = {
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'openai-gpt': 512,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
CORPUS_NAME = 'corpus.bin'
class TransfoXLTokenizer(object):
"""
@@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a TransfoXLTokenizer.
Download and cache the vocabulary if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
vocab_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
tokenizer = cls(*inputs, **kwargs)
vocab_dict = torch.load(resolved_vocab_file)
for key, value in vocab_dict.items():
tokenizer.__dict__[key] = value
return tokenizer
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
@@ -418,10 +408,53 @@ class LMMultiFileIterator(LMShuffledIterator):
yield batch
class Corpus(object):
def __init__(self, path, dataset, *args, **kwargs):
class TransfoXLCorpus(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a pre-processed corpus.
"""
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
# redirect to the cache, if necessary
try:
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
corpus_file))
return None
if resolved_corpus_file == corpus_file:
logger.info("loading corpus file {}".format(corpus_file))
else:
logger.info("loading corpus file {} from cache at {}".format(
corpus_file, resolved_corpus_file))
# Instantiate the corpus.
corpus = cls(*inputs, **kwargs)
corpus_dict = torch.load(resolved_corpus_file)
for key, value in corpus_dict.items():
corpus.__dict__[key] = value
corpus.vocab = vocab
return corpus
def __init__(self, *args, **kwargs):
self.vocab = TransfoXLTokenizer(*args, **kwargs)
self.dataset = None
self.train = None
self.valid = None
self.test = None
def build_corpus(self, path, dataset):
self.dataset = dataset
self.vocab = Vocab(*args, **kwargs)
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
self.vocab.count_file(os.path.join(path, 'train.txt'))
@@ -443,20 +476,20 @@ class Corpus(object):
os.path.join(path, 'train.txt'), ordered=True)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True)
elif self.dataset in ['enwik8', 'text8']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
elif self.dataset == 'lm1b':
self.train = train_paths
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
def get_iterator(self, split, *args, **kwargs):
@@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
elif dataset in ['enwik8', 'text8']:
pass
corpus = Corpus(datadir, dataset, **kwargs)
corpus = TransfoXLCorpus(**kwargs)
corpus.build_corpus(datadir, dataset)
torch.save(corpus, fn)
return corpus
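Put together, TransfoXLCorpus.from_pretrained and get_iterator provide the data side of the evaluation script above. A minimal sketch mirroring examples/eval_transfo_xl.py (the cached vocabulary and corpus are downloaded on first use; batch size 10 and tgt_len 5 are the script's defaults):

import torch
from pytorch_pretrained_bert import TransfoXLCorpus

corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
print('vocab size:', len(corpus.vocab))

# Ordered iterator over the validation split; each step yields
# (data, target, seq_len) with data/target of shape (tgt_len, batch_size).
va_iter = corpus.get_iterator('valid', 10, 5,
                              device=torch.device('cpu'), ext_len=0)
for data, target, seq_len in va_iter:
    break  # one batch is enough to show the interface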