mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-29 17:22:25 +06:00
improved corpus and tokenization conversion - added evaluation script
This commit is contained in:
parent
7d03c53718
commit
a69ec2c722
151
examples/eval_transfo_xl.py
Normal file
151
examples/eval_transfo_xl.py
Normal file
@ -0,0 +1,151 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Transformer XL model evaluation script.
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import functools
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
|
||||
|
||||
def logging(s, log_path, print_=True, log_=True):
|
||||
if print_:
|
||||
print(s)
|
||||
if log_:
|
||||
with open(log_path, 'a+') as f_log:
|
||||
f_log.write(s + '\n')
|
||||
|
||||
def get_logger(log_path, **kwargs):
|
||||
return functools.partial(logging, log_path=log_path, **kwargs)
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
|
||||
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
|
||||
# help='location of the data corpus')
|
||||
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
|
||||
choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
|
||||
help='pretrained model name')
|
||||
parser.add_argument('--split', type=str, default='all',
|
||||
choices=['all', 'valid', 'test'],
|
||||
help='which split to evaluate')
|
||||
parser.add_argument('--batch_size', type=int, default=10,
|
||||
help='batch size')
|
||||
parser.add_argument('--tgt_len', type=int, default=5,
|
||||
help='number of tokens to predict')
|
||||
parser.add_argument('--ext_len', type=int, default=0,
|
||||
help='length of the extended context')
|
||||
parser.add_argument('--mem_len', type=int, default=0,
|
||||
help='length of the retained previous heads')
|
||||
parser.add_argument('--clamp_len', type=int, default=-1,
|
||||
help='max positional embedding index')
|
||||
parser.add_argument('--cuda', action='store_true',
|
||||
help='use CUDA')
|
||||
parser.add_argument('--work_dir', type=str, required=True,
|
||||
help='path to the work_dir')
|
||||
parser.add_argument('--no_log', action='store_true',
|
||||
help='do not log the eval result')
|
||||
parser.add_argument('--same_length', action='store_true',
|
||||
help='set same length attention with masking')
|
||||
args = parser.parse_args()
|
||||
assert args.ext_len >= 0, 'extended context length must be non-negative'
|
||||
|
||||
device = torch.device("cuda" if args.cuda else "cpu")
|
||||
|
||||
# Get logger
|
||||
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
|
||||
log_=not args.no_log)
|
||||
|
||||
# Load dataset
|
||||
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
||||
ntokens = len(corpus.vocab)
|
||||
|
||||
va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
|
||||
device=device, ext_len=args.ext_len)
|
||||
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
|
||||
device=device, ext_len=args.ext_len)
|
||||
|
||||
# Load the best saved model.
|
||||
# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
|
||||
# model = torch.load(f)
|
||||
# model.backward_compatible()
|
||||
model = TransfoXLModel.from_pretrained(args.model_name)
|
||||
model = model.to(device)
|
||||
|
||||
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
||||
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
|
||||
|
||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
||||
if args.clamp_len > 0:
|
||||
model.clamp_len = args.clamp_len
|
||||
if args.same_length:
|
||||
model.same_length = True
|
||||
|
||||
###############################################################################
|
||||
# Evaluation code
|
||||
###############################################################################
|
||||
def evaluate(eval_iter):
|
||||
# Turn on evaluation mode which disables dropout.
|
||||
model.eval()
|
||||
total_len, total_loss = 0, 0.
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
mems = tuple()
|
||||
for idx, (data, target, seq_len) in enumerate(eval_iter):
|
||||
ret = model(data, target, *mems)
|
||||
loss, mems = ret[0], ret[1:]
|
||||
loss = loss.mean()
|
||||
total_loss += seq_len * loss.item()
|
||||
total_len += seq_len
|
||||
total_time = time.time() - start_time
|
||||
logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
||||
total_time, 1000 * total_time / (idx+1)))
|
||||
return total_loss / total_len
|
||||
|
||||
# Run on test data.
|
||||
if args.split == 'all':
|
||||
test_loss = evaluate(te_iter)
|
||||
valid_loss = evaluate(va_iter)
|
||||
elif args.split == 'valid':
|
||||
valid_loss = evaluate(va_iter)
|
||||
test_loss = None
|
||||
elif args.split == 'test':
|
||||
test_loss = evaluate(te_iter)
|
||||
valid_loss = None
|
||||
|
||||
def format_log(loss, split):
|
||||
if args.dataset in ['enwik8', 'text8']:
|
||||
log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
|
||||
split, loss, loss / math.log(2))
|
||||
else:
|
||||
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
|
||||
split, loss, math.exp(loss))
|
||||
return log_str
|
||||
|
||||
log_str = ''
|
||||
if valid_loss is not None:
|
||||
log_str += format_log(valid_loss, 'valid')
|
||||
if test_loss is not None:
|
||||
log_str += format_log(test_loss, 'test')
|
||||
|
||||
logging('=' * 100)
|
||||
logging(log_str)
|
||||
logging('=' * 100)
|
@ -1,12 +1,14 @@
|
||||
__version__ = "0.5.0"
|
||||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForMultipleChoice,
|
||||
BertForTokenClassification, BertForQuestionAnswering)
|
||||
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
|
@ -12,23 +12,36 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert OpenAI GPT checkpoint."""
|
||||
"""Convert Transformer XL checkpoint and datasets."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import argparse
|
||||
import pickle
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
|
||||
|
||||
# We do this to be able to load the python 2 datasets pickles
|
||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
|
||||
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
||||
data_utils.Corpus = data_utils.TransfoXLCorpus
|
||||
sys.modules['data_utils'] = data_utils
|
||||
sys.modules['vocabulary'] = data_utils
|
||||
|
||||
def build_tf_to_pytorch_map(model, config):
|
||||
""" A map of modules from TF to PyTorch """
|
||||
""" A map of modules from TF to PyTorch.
|
||||
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
|
||||
"""
|
||||
tf_to_pt_map = {}
|
||||
# Embeddings cutoffs
|
||||
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
|
||||
@ -95,88 +108,108 @@ def build_tf_to_pytorch_map(model, config):
|
||||
|
||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
transfo_xl_config_file,
|
||||
pytorch_dump_folder_path):
|
||||
config_path = os.path.abspath(transfo_xl_config_file)
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
pytorch_dump_folder_path,
|
||||
transfo_xl_dataset_file):
|
||||
if transfo_xl_dataset_file:
|
||||
with open(transfo_xl_dataset_file, "rb") as fp:
|
||||
corpus = pickle.load(fp, encoding="latin1")
|
||||
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
|
||||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
|
||||
print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
|
||||
torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
|
||||
|
||||
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
|
||||
# Initialise PyTorch model
|
||||
# Construct model
|
||||
if transfo_xl_config_file == "":
|
||||
config = TransfoXLConfig()
|
||||
else:
|
||||
config = TransfoXLConfig(transfo_xl_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = TransfoXLModel(config)
|
||||
corpus_dict_no_vocab = corpus.__dict__
|
||||
corpus_dict_no_vocab.pop('vocab', None)
|
||||
pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
|
||||
print("Save dataset to {}".format(pytorch_dataset_dump_path))
|
||||
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
|
||||
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
|
||||
if tf_checkpoint_path:
|
||||
config_path = os.path.abspath(transfo_xl_config_file)
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
|
||||
# Initialise PyTorch model
|
||||
# Construct model
|
||||
if transfo_xl_config_file == "":
|
||||
config = TransfoXLConfig()
|
||||
else:
|
||||
config = TransfoXLConfig(transfo_xl_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = TransfoXLModel(config)
|
||||
|
||||
for name, pointer in tf_to_pt_map.items():
|
||||
assert name in tf_weights
|
||||
array = tf_weights[name]
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if 'kernel' in name or 'proj_W' in name:
|
||||
array = np.transpose(array)
|
||||
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
|
||||
# Here we will split the TF weigths
|
||||
assert len(pointer) == array.shape[0]
|
||||
for i, p_i in enumerate(pointer):
|
||||
arr_i = array[i, ...]
|
||||
try:
|
||||
assert p_i.shape == arr_i.shape
|
||||
except AssertionError as e:
|
||||
e.args += (p_i.shape, arr_i.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
continue
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
|
||||
for name, pointer in tf_to_pt_map.items():
|
||||
assert name in tf_weights
|
||||
array = tf_weights[name]
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if 'kernel' in name or 'proj_W' in name:
|
||||
array = np.transpose(array)
|
||||
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
|
||||
# Here we will split the TF weigths
|
||||
assert len(pointer) == array.shape[0]
|
||||
for i, p_i in enumerate(pointer):
|
||||
arr_i = array[i, ...]
|
||||
try:
|
||||
assert p_i.shape == arr_i.shape
|
||||
except AssertionError as e:
|
||||
e.args += (p_i.shape, arr_i.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
continue
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path the TensorFlow checkpoint path.")
|
||||
parser.add_argument("--pytorch_dump_folder_path",
|
||||
default = None,
|
||||
type = str,
|
||||
required = True,
|
||||
help = "Path to the output PyTorch model.")
|
||||
help = "Path to the folder to store the PyTorch model or dataset/vocab.")
|
||||
parser.add_argument("--tf_checkpoint_path",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional path to a TensorFlow checkpoint path to be converted.")
|
||||
parser.add_argument("--transfo_xl_config_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||
help = "An optional config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture.")
|
||||
parser.add_argument("--transfo_xl_dataset_file",
|
||||
default = "",
|
||||
type = str,
|
||||
help = "An optional dataset file to be converted in a vocabulary.")
|
||||
args = parser.parse_args()
|
||||
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
|
||||
args.transfo_xl_config_file,
|
||||
args.pytorch_dump_folder_path)
|
||||
args.pytorch_dump_folder_path,
|
||||
args.transfo_xl_dataset_file)
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Transformer XL model.
|
||||
Directly adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
||||
"""
|
||||
|
||||
@ -40,7 +40,7 @@ from .file_utils import cached_path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'transfo_xl_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
@ -59,12 +59,13 @@ class TransfoXLConfig(object):
|
||||
div_val=4,
|
||||
pre_lnorm=False,
|
||||
n_layer=18,
|
||||
tgt_len=256,
|
||||
tgt_len=128,
|
||||
ext_len=0,
|
||||
mem_len=256,
|
||||
same_length=False,
|
||||
mem_len=1600,
|
||||
clamp_len=1000,
|
||||
same_length=True,
|
||||
proj_share_all_but_first=True,
|
||||
attn_type=0,
|
||||
clamp_len=-1,
|
||||
sample_softmax=-1,
|
||||
adaptive=True,
|
||||
tie_weight=True,
|
||||
@ -93,6 +94,7 @@ class TransfoXLConfig(object):
|
||||
ext_len: length of the extended context
|
||||
mem_len: length of the retained previous heads
|
||||
same_length: use the same attn length for all tokens
|
||||
proj_share_all_but_first: True to share all but first projs, False not to share.
|
||||
attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
|
||||
clamp_len: use the same pos embeddings after clamp_len
|
||||
sample_softmax: number of samples in sampled softmax
|
||||
@ -118,7 +120,10 @@ class TransfoXLConfig(object):
|
||||
self.cutoffs = []
|
||||
self.cutoffs.extend(cutoffs)
|
||||
self.tie_weight = tie_weight
|
||||
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
||||
if proj_share_all_but_first:
|
||||
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
||||
else:
|
||||
self.tie_projs = [False] + [False] * len(self.cutoffs)
|
||||
self.d_model = d_model
|
||||
self.d_embed = d_embed
|
||||
self.d_head = d_head
|
||||
@ -423,7 +428,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
|
||||
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
|
||||
|
||||
#### compute attention score
|
||||
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
|
||||
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
|
||||
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
|
||||
|
||||
rr_head_q = w_head_q + self.r_r_bias
|
||||
@ -915,21 +920,25 @@ class MemTransformerLM(nn.Module):
|
||||
|
||||
return core_out, new_mems
|
||||
|
||||
def forward(self, data, target, *mems):
|
||||
def forward(self, data, target=None, *mems):
|
||||
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
|
||||
# So, have to initialize size(0) mems inside the model forward.
|
||||
# Moreover, have to return new_mems to allow nn.DataParallel to piece
|
||||
# them together.
|
||||
if not mems: mems = self.init_mems()
|
||||
|
||||
tgt_len = target.size(0)
|
||||
hidden, new_mems = self._forward(data, mems=mems)
|
||||
if target is None:
|
||||
if new_mems is None:
|
||||
return [hidden]
|
||||
else:
|
||||
return [hidden] + new_mems
|
||||
|
||||
tgt_len = target.size(0)
|
||||
pred_hid = hidden[-tgt_len:]
|
||||
if self.sample_softmax > 0 and self.training:
|
||||
assert self.tie_weight
|
||||
logit = sample_logits(self.word_emb,
|
||||
self.out_layer.bias, target, pred_hid, self.sampler)
|
||||
logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
|
||||
loss = -F.log_softmax(logit, -1)[:, :, 0]
|
||||
else:
|
||||
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
|
||||
@ -1010,7 +1019,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||
*inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
@ -1100,7 +1109,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
||||
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
model.__class__.__name__, missing_keys))
|
||||
@ -1110,9 +1119,6 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
# Add additional embeddings for special tokens if needed
|
||||
if num_special_tokens != config.n_special:
|
||||
model.set_num_special_tokens(num_special_tokens)
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
|
@ -14,15 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization classes for Transformer XL model.
|
||||
Directly adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import glob
|
||||
import logging
|
||||
import pickle
|
||||
import torch
|
||||
from collections import Counter, OrderedDict
|
||||
|
||||
from .file_utils import cached_path
|
||||
@ -30,16 +29,14 @@ from .file_utils import cached_path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
|
||||
}
|
||||
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
|
||||
VOCAB_NAME = 'vocab.bin'
|
||||
|
||||
PRETRAINED_CORPUS_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
|
||||
}
|
||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'openai-gpt': 512,
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
CORPUS_NAME = 'corpus.bin'
|
||||
|
||||
class TransfoXLTokenizer(object):
|
||||
"""
|
||||
@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a TransfoXLTokenizer.
|
||||
Download and cache the vocabulary if needed.
|
||||
The TransfoXLTokenizer.
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path,
|
||||
vocab_file, merges_file))
|
||||
vocab_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||
if resolved_vocab_file == vocab_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
logger.info("loading merges file {}".format(merges_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
logger.info("loading merges file {} from cache at {}".format(
|
||||
merges_file, resolved_merges_file))
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||
tokenizer = cls(*inputs, **kwargs)
|
||||
vocab_dict = torch.load(resolved_vocab_file)
|
||||
for key, value in vocab_dict.items():
|
||||
tokenizer.__dict__[key] = value
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
|
||||
@ -418,10 +408,53 @@ class LMMultiFileIterator(LMShuffledIterator):
|
||||
yield batch
|
||||
|
||||
|
||||
class Corpus(object):
|
||||
def __init__(self, path, dataset, *args, **kwargs):
|
||||
class TransfoXLCorpus(object):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a pre-processed corpus.
|
||||
"""
|
||||
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
|
||||
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path,
|
||||
corpus_file))
|
||||
return None
|
||||
if resolved_corpus_file == corpus_file:
|
||||
logger.info("loading corpus file {}".format(corpus_file))
|
||||
else:
|
||||
logger.info("loading corpus file {} from cache at {}".format(
|
||||
corpus_file, resolved_corpus_file))
|
||||
|
||||
# Instantiate tokenizer.
|
||||
corpus = cls(*inputs, **kwargs)
|
||||
corpus_dict = torch.load(resolved_corpus_file)
|
||||
for key, value in corpus_dict.items():
|
||||
corpus.__dict__[key] = value
|
||||
corpus.vocab = vocab
|
||||
return corpus
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.vocab = TransfoXLTokenizer(*args, **kwargs)
|
||||
self.dataset = None
|
||||
self.train = None
|
||||
self.valid = None
|
||||
self.test = None
|
||||
|
||||
def build_corpus(self, path, dataset):
|
||||
self.dataset = dataset
|
||||
self.vocab = Vocab(*args, **kwargs)
|
||||
|
||||
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
|
||||
self.vocab.count_file(os.path.join(path, 'train.txt'))
|
||||
@ -443,20 +476,20 @@ class Corpus(object):
|
||||
os.path.join(path, 'train.txt'), ordered=True)
|
||||
self.valid = self.vocab.encode_file(
|
||||
os.path.join(path, 'valid.txt'), ordered=True)
|
||||
self.test = self.vocab.encode_file(
|
||||
self.test = self.vocab.encode_file(
|
||||
os.path.join(path, 'test.txt'), ordered=True)
|
||||
elif self.dataset in ['enwik8', 'text8']:
|
||||
self.train = self.vocab.encode_file(
|
||||
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
|
||||
self.valid = self.vocab.encode_file(
|
||||
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
|
||||
self.test = self.vocab.encode_file(
|
||||
self.test = self.vocab.encode_file(
|
||||
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
|
||||
elif self.dataset == 'lm1b':
|
||||
self.train = train_paths
|
||||
self.valid = self.vocab.encode_file(
|
||||
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
|
||||
self.test = self.vocab.encode_file(
|
||||
self.test = self.vocab.encode_file(
|
||||
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
|
||||
|
||||
def get_iterator(self, split, *args, **kwargs):
|
||||
@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
|
||||
elif dataset in ['enwik8', 'text8']:
|
||||
pass
|
||||
|
||||
corpus = Corpus(datadir, dataset, **kwargs)
|
||||
corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
|
||||
torch.save(corpus, fn)
|
||||
|
||||
return corpus
|
||||
|
Loading…
Reference in New Issue
Block a user