mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 09:42:22 +06:00)

improved corpus and tokenization conversion - added evaluation script

This commit is contained in:
parent 7d03c53718
commit a69ec2c722

examples/eval_transfo_xl.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
"""
import os
import sys
import functools
import argparse
import time
import math

import torch

from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')

def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
#                     help='location of the data corpus')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
                    help='pretrained model name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'

device = torch.device("cuda" if args.cuda else "cpu")

# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)

# Load the best saved model.
# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
#     model = torch.load(f)
# model.backward_compatible()
model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len

# Run on test data.
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None

def format_log(loss, split):
    if args.model_name in ['enwik8', 'text8']:
        # character-level corpora report bits per character
        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
            split, loss, loss / math.log(2))
    else:
        # word-level corpora report perplexity
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
    return log_str

log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logging('=' * 100)
logging(log_str)
logging('=' * 100)
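For reference, a minimal invocation of the script above might look like the following (the --work_dir value is an illustrative placeholder, --model_name already defaults to transfo-xl-wt103, and the length settings simply mirror the new TransfoXLConfig defaults introduced later in this commit):

python examples/eval_transfo_xl.py --work_dir ./eval-wt103 --cuda --tgt_len 128 --mem_len 1600 --clamp_len 1000 --same_length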
@@ -1,12 +1,14 @@
 __version__ = "0.5.0"
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
+from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
                        BertForSequenceClassification, BertForMultipleChoice,
                        BertForTokenClassification, BertForQuestionAnswering)
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
+from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
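With the two added imports, the Transformer XL classes become importable from the package root. A minimal sketch of what this exposes (the model/corpus name comes from the archive maps added elsewhere in this commit):

from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLCorpus, TransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()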
@@ -12,23 +12,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Convert OpenAI GPT checkpoint."""
+"""Convert Transformer XL checkpoint and datasets."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import os
-import re
+import sys
 import argparse
+import pickle

 import tensorflow as tf
 import torch
 import numpy as np

 from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
+from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME

+# We do this to be able to load the python 2 datasets pickles
+# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
+import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
+data_utils.Vocab = data_utils.TransfoXLTokenizer
+data_utils.Corpus = data_utils.TransfoXLCorpus
+sys.modules['data_utils'] = data_utils
+sys.modules['vocabulary'] = data_utils

 def build_tf_to_pytorch_map(model, config):
-    """ A map of modules from TF to PyTorch """
+    """ A map of modules from TF to PyTorch.
+        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
+    """
     tf_to_pt_map = {}
     # Embeddings cutoffs
     for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
@@ -95,88 +108,108 @@ def build_tf_to_pytorch_map(model, config):

 def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                              transfo_xl_config_file,
-                                             pytorch_dump_folder_path):
-    config_path = os.path.abspath(transfo_xl_config_file)
-    tf_path = os.path.abspath(tf_checkpoint_path)
-
-    print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
-    # Initialise PyTorch model
-    # Construct model
-    if transfo_xl_config_file == "":
-        config = TransfoXLConfig()
-    else:
-        config = TransfoXLConfig(transfo_xl_config_file)
-    print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = TransfoXLModel(config)
-
-    # Build TF to PyTorch weights loading map
-    tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
-
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    tf_weights = {}
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        tf_weights[name] = array
-
-    for name, pointer in tf_to_pt_map.items():
-        assert name in tf_weights
-        array = tf_weights[name]
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
-        # which are not required for using pretrained model
-        if 'kernel' in name or 'proj_W' in name:
-            array = np.transpose(array)
-        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
-            # Here we will split the TF weights
-            assert len(pointer) == array.shape[0]
-            for i, p_i in enumerate(pointer):
-                arr_i = array[i, ...]
-                try:
-                    assert p_i.shape == arr_i.shape
-                except AssertionError as e:
-                    e.args += (p_i.shape, arr_i.shape)
-                    raise
-                print("Initialize PyTorch weight {} for layer {}".format(name, i))
-                p_i.data = torch.from_numpy(arr_i)
-            continue
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        print("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
-
-    # Save pytorch-model
-    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
-    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
-    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
-    torch.save(model.state_dict(), pytorch_weights_dump_path)
-    print("Save configuration file to {}".format(pytorch_config_dump_path))
-    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
-        f.write(config.to_json_string())
+                                             pytorch_dump_folder_path,
+                                             transfo_xl_dataset_file):
+    if transfo_xl_dataset_file:
+        with open(transfo_xl_dataset_file, "rb") as fp:
+            corpus = pickle.load(fp, encoding="latin1")
+        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
+        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME
+        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
+        torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)
+
+        corpus_dict_no_vocab = corpus.__dict__
+        corpus_dict_no_vocab.pop('vocab', None)
+        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
+        print("Save dataset to {}".format(pytorch_dataset_dump_path))
+        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
+
+    if tf_checkpoint_path:
+        config_path = os.path.abspath(transfo_xl_config_file)
+        tf_path = os.path.abspath(tf_checkpoint_path)
+
+        print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
+        # Initialise PyTorch model
+        # Construct model
+        if transfo_xl_config_file == "":
+            config = TransfoXLConfig()
+        else:
+            config = TransfoXLConfig(transfo_xl_config_file)
+        print("Building PyTorch model from configuration: {}".format(str(config)))
+        model = TransfoXLModel(config)
+
+        # Build TF to PyTorch weights loading map
+        tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config)
+
+        # Load weights from TF model
+        init_vars = tf.train.list_variables(tf_path)
+        tf_weights = {}
+        for name, shape in init_vars:
+            print("Loading TF weight {} with shape {}".format(name, shape))
+            array = tf.train.load_variable(tf_path, name)
+            tf_weights[name] = array
+
+        for name, pointer in tf_to_pt_map.items():
+            assert name in tf_weights
+            array = tf_weights[name]
+            # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
+            # which are not required for using pretrained model
+            if 'kernel' in name or 'proj_W' in name:
+                array = np.transpose(array)
+            if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
+                # Here we will split the TF weights
+                assert len(pointer) == array.shape[0]
+                for i, p_i in enumerate(pointer):
+                    arr_i = array[i, ...]
+                    try:
+                        assert p_i.shape == arr_i.shape
+                    except AssertionError as e:
+                        e.args += (p_i.shape, arr_i.shape)
+                        raise
+                    print("Initialize PyTorch weight {} for layer {}".format(name, i))
+                    p_i.data = torch.from_numpy(arr_i)
+                continue
+            try:
+                assert pointer.shape == array.shape
+            except AssertionError as e:
+                e.args += (pointer.shape, array.shape)
+                raise
+            print("Initialize PyTorch weight {}".format(name))
+            pointer.data = torch.from_numpy(array)
+
+        # Save pytorch-model
+        pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
+        pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
+        print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
+        torch.save(model.state_dict(), pytorch_weights_dump_path)
+        print("Save configuration file to {}".format(pytorch_config_dump_path))
+        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+            f.write(config.to_json_string())

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
-    parser.add_argument("--tf_checkpoint_path",
-                        default = None,
-                        type = str,
-                        required = True,
-                        help = "Path the TensorFlow checkpoint path.")
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
-                        help = "Path to the output PyTorch model.")
+                        help = "Path to the folder to store the PyTorch model or dataset/vocab.")
+    parser.add_argument("--tf_checkpoint_path",
+                        default = "",
+                        type = str,
+                        help = "An optional path to a TensorFlow checkpoint path to be converted.")
     parser.add_argument("--transfo_xl_config_file",
                         default = "",
                         type = str,
-                        help = "The config json file corresponding to the pre-trained BERT model. \n"
+                        help = "An optional config json file corresponding to the pre-trained BERT model. \n"
                             "This specifies the model architecture.")
+    parser.add_argument("--transfo_xl_dataset_file",
+                        default = "",
+                        type = str,
+                        help = "An optional dataset file to be converted in a vocabulary.")
     args = parser.parse_args()
     convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                              args.transfo_xl_config_file,
-                                             args.pytorch_dump_folder_path)
+                                             args.pytorch_dump_folder_path,
+                                             args.transfo_xl_dataset_file)
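As a usage sketch of the updated entry point (all paths below are hypothetical placeholders, not taken from this commit), the function can now convert a pickled corpus, a TensorFlow checkpoint, or both in a single call:

convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path="",                          # empty: skip the TF weight conversion
    transfo_xl_config_file="",
    pytorch_dump_folder_path="./transfo-xl-dump",   # hypothetical output folder
    transfo_xl_dataset_file="./corpus-cache.pkl")   # hypothetical pickled corpus from the original codebase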
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
     In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
 """

@@ -40,7 +40,7 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_MODEL_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
 }
 CONFIG_NAME = 'transfo_xl_config.json'
 WEIGHTS_NAME = 'pytorch_model.bin'
@@ -59,12 +59,13 @@ class TransfoXLConfig(object):
                  div_val=4,
                  pre_lnorm=False,
                  n_layer=18,
-                 tgt_len=256,
+                 tgt_len=128,
                  ext_len=0,
-                 mem_len=256,
-                 same_length=False,
+                 mem_len=1600,
+                 clamp_len=1000,
+                 same_length=True,
+                 proj_share_all_but_first=True,
                  attn_type=0,
-                 clamp_len=-1,
                  sample_softmax=-1,
                  adaptive=True,
                  tie_weight=True,
@@ -93,6 +94,7 @@ class TransfoXLConfig(object):
             ext_len: length of the extended context
             mem_len: length of the retained previous heads
             same_length: use the same attn length for all tokens
+            proj_share_all_but_first: True to share all but first projs, False not to share.
             attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
             clamp_len: use the same pos embeddings after clamp_len
             sample_softmax: number of samples in sampled softmax
|
|||||||
self.cutoffs = []
|
self.cutoffs = []
|
||||||
self.cutoffs.extend(cutoffs)
|
self.cutoffs.extend(cutoffs)
|
||||||
self.tie_weight = tie_weight
|
self.tie_weight = tie_weight
|
||||||
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
if proj_share_all_but_first:
|
||||||
|
self.tie_projs = [False] + [True] * len(self.cutoffs)
|
||||||
|
else:
|
||||||
|
self.tie_projs = [False] + [False] * len(self.cutoffs)
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.d_embed = d_embed
|
self.d_embed = d_embed
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
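A small sketch of what the new proj_share_all_but_first switch does to tie_projs, assuming the default constructor path shown in this hunk is taken:

from pytorch_pretrained_bert import TransfoXLConfig

config = TransfoXLConfig(proj_share_all_but_first=True)
print(config.tie_projs)   # [False, True, True, ...] - only the first projection stays untied
config = TransfoXLConfig(proj_share_all_but_first=False)
print(config.tie_projs)   # [False, False, False, ...] - no projections are tied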
@@ -423,7 +428,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
         r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)           # qlen x n_head x d_head

         #### compute attention score
         rw_head_q = w_head_q + self.r_w_bias                               # qlen x bsz x n_head x d_head
         AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))        # qlen x klen x bsz x n_head

         rr_head_q = w_head_q + self.r_r_bias
@@ -915,21 +920,25 @@ class MemTransformerLM(nn.Module):

         return core_out, new_mems

-    def forward(self, data, target, *mems):
+    def forward(self, data, target=None, *mems):
         # nn.DataParallel does not allow size(0) tensors to be broadcasted.
         # So, have to initialize size(0) mems inside the model forward.
         # Moreover, have to return new_mems to allow nn.DataParallel to piece
         # them together.
         if not mems: mems = self.init_mems()

-        tgt_len = target.size(0)
         hidden, new_mems = self._forward(data, mems=mems)
+        if target is None:
+            if new_mems is None:
+                return [hidden]
+            else:
+                return [hidden] + new_mems
+
+        tgt_len = target.size(0)
         pred_hid = hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
             assert self.tie_weight
-            logit = sample_logits(self.word_emb,
-                                  self.out_layer.bias, target, pred_hid, self.sampler)
+            logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
             loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
|
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||||
*inputs, **kwargs):
|
*inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
@ -1100,7 +1109,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
|||||||
for name, child in module._modules.items():
|
for name, child in module._modules.items():
|
||||||
if child is not None:
|
if child is not None:
|
||||||
load(child, prefix + name + '.')
|
load(child, prefix + name + '.')
|
||||||
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||||
model.__class__.__name__, missing_keys))
|
model.__class__.__name__, missing_keys))
|
||||||
@@ -1110,9 +1119,6 @@ class TransfoXLPreTrainedModel(nn.Module):
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))
-        # Add additional embeddings for special tokens if needed
-        if num_special_tokens != config.n_special:
-            model.set_num_special_tokens(num_special_tokens)
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
@@ -14,15 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for Transformer XL model.
-    Directly adapted from https://github.com/kimiyoung/transformer-xl.
+    Adapted from https://github.com/kimiyoung/transformer-xl.
 """

 import os
-import re
-import json
-from tqdm import tqdm
+import glob
 import logging
 import pickle
+import torch
 from collections import Counter, OrderedDict

 from .file_utils import cached_path
@@ -30,16 +29,14 @@ from .file_utils import cached_path
 logger = logging.getLogger(__name__)

 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
 }
-PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
-}
-PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
-    'openai-gpt': 512,
-}
-VOCAB_NAME = 'vocab.json'
-MERGES_NAME = 'merges.txt'
+VOCAB_NAME = 'vocab.bin'
+
+PRETRAINED_CORPUS_ARCHIVE_MAP = {
+    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
+}
+CORPUS_NAME = 'corpus.bin'

 class TransfoXLTokenizer(object):
     """
@@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
         Instantiate a TransfoXLTokenizer.
-        Download and cache the vocabulary if needed.
+        The TransfoXLTokenizer.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except FileNotFoundError:
             logger.error(
                 "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                "We assumed '{}' was a path or url but couldn't find files {} "
                 "at this path or url.".format(
                     pretrained_model_name_or_path,
                     ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                     pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+                    vocab_file))
             return None
-        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+        if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
-            logger.info("loading merges file {}".format(merges_file))
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
-            logger.info("loading merges file {} from cache at {}".format(
-                merges_file, resolved_merges_file))
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
-            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
-            # than the number of positional embeddings
-            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
-            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        tokenizer = cls(*inputs, **kwargs)
+        vocab_dict = torch.load(resolved_vocab_file)
+        for key, value in vocab_dict.items():
+            tokenizer.__dict__[key] = value
         return tokenizer

     def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
|
|||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
class Corpus(object):
|
class TransfoXLCorpus(object):
|
||||||
def __init__(self, path, dataset, *args, **kwargs):
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||||
|
"""
|
||||||
|
Instantiate a pre-processed corpus.
|
||||||
|
"""
|
||||||
|
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
|
||||||
|
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||||
|
else:
|
||||||
|
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
|
||||||
|
# redirect to the cache, if necessary
|
||||||
|
try:
|
||||||
|
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(
|
||||||
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
|
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||||
|
"at this path or url.".format(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
corpus_file))
|
||||||
|
return None
|
||||||
|
if resolved_corpus_file == corpus_file:
|
||||||
|
logger.info("loading corpus file {}".format(corpus_file))
|
||||||
|
else:
|
||||||
|
logger.info("loading corpus file {} from cache at {}".format(
|
||||||
|
corpus_file, resolved_corpus_file))
|
||||||
|
|
||||||
|
# Instantiate tokenizer.
|
||||||
|
corpus = cls(*inputs, **kwargs)
|
||||||
|
corpus_dict = torch.load(resolved_corpus_file)
|
||||||
|
for key, value in corpus_dict.items():
|
||||||
|
corpus.__dict__[key] = value
|
||||||
|
corpus.vocab = vocab
|
||||||
|
return corpus
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.vocab = TransfoXLTokenizer(*args, **kwargs)
|
||||||
|
self.dataset = None
|
||||||
|
self.train = None
|
||||||
|
self.valid = None
|
||||||
|
self.test = None
|
||||||
|
|
||||||
|
def build_corpus(self, path, dataset):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.vocab = Vocab(*args, **kwargs)
|
|
||||||
|
|
||||||
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
|
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
|
||||||
self.vocab.count_file(os.path.join(path, 'train.txt'))
|
self.vocab.count_file(os.path.join(path, 'train.txt'))
|
||||||
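Together with the new archive maps, the class above lets the pre-processed WikiText-103 corpus be loaded by name. A minimal sketch, mirroring the evaluation script earlier in this commit (batch size, target length and device below are illustrative):

import torch
from pytorch_pretrained_bert import TransfoXLCorpus

corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
print(len(corpus.vocab))  # vocabulary size restored from the cached vocab dict
va_iter = corpus.get_iterator('valid', 10, 128, device=torch.device('cpu'), ext_len=0)
for data, target, seq_len in va_iter:
    break  # each batch is a (data, target, seq_len) triple, as consumed by evaluate() above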
@@ -443,20 +476,20 @@ class Corpus(object):
                 os.path.join(path, 'train.txt'), ordered=True)
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=True)
             self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=True)
         elif self.dataset in ['enwik8', 'text8']:
             self.train = self.vocab.encode_file(
                 os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
             self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
         elif self.dataset == 'lm1b':
             self.train = train_paths
             self.valid = self.vocab.encode_file(
                 os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
             self.test = self.vocab.encode_file(
                 os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

     def get_iterator(self, split, *args, **kwargs):
@@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
     elif dataset in ['enwik8', 'text8']:
         pass

-    corpus = Corpus(datadir, dataset, **kwargs)
+    corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
     torch.save(corpus, fn)

     return corpus