add two transformer xl models

This commit is contained in:
thomwolf 2019-02-07 17:07:03 +01:00
parent d482e3d79d
commit c306869ea2
3 changed files with 174 additions and 64 deletions

View File

@ -11,7 +11,7 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt) load_tf_weights_in_openai_gpt)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl) load_tf_weights_in_transfo_xl)
from .optimization import BertAdam from .optimization import BertAdam

View File

@ -27,7 +27,7 @@ import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
TransfoXLConfig, TransfoXLConfig,
TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl) load_tf_weights_in_transfo_xl)
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
VOCAB_NAME) VOCAB_NAME)
@ -37,7 +37,7 @@ if sys.version_info[0] == 2:
else: else:
import pickle import pickle
# We do this to be able to load the python 2 datasets pickles # We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils.Vocab = data_utils.TransfoXLTokenizer data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus data_utils.Corpus = data_utils.TransfoXLCorpus
@ -49,6 +49,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
pytorch_dump_folder_path, pytorch_dump_folder_path,
transfo_xl_dataset_file): transfo_xl_dataset_file):
if transfo_xl_dataset_file: if transfo_xl_dataset_file:
# Convert a pre-processed corpus (see original TensorFlow repo)
with open(transfo_xl_dataset_file, "rb") as fp: with open(transfo_xl_dataset_file, "rb") as fp:
corpus = pickle.load(fp, encoding="latin1") corpus = pickle.load(fp, encoding="latin1")
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
@ -64,18 +65,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
if tf_checkpoint_path: if tf_checkpoint_path:
# Convert a pre-trained TensorFlow model
config_path = os.path.abspath(transfo_xl_config_file) config_path = os.path.abspath(transfo_xl_config_file)
tf_path = os.path.abspath(tf_checkpoint_path) tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
# Initialise PyTorch model # Initialise PyTorch model
# Construct model
if transfo_xl_config_file == "": if transfo_xl_config_file == "":
config = TransfoXLConfig() config = TransfoXLConfig()
else: else:
config = TransfoXLConfig(transfo_xl_config_file) config = TransfoXLConfig(transfo_xl_config_file)
print("Building PyTorch model from configuration: {}".format(str(config))) print("Building PyTorch model from configuration: {}".format(str(config)))
model = TransfoXLModel(config) model = TransfoXLLMHeadModel(config)
model = load_tf_weights_in_transfo_xl(model, config, tf_path) model = load_tf_weights_in_transfo_xl(model, config, tf_path)
# Save pytorch-model # Save pytorch-model
@ -90,7 +91,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--pytorch_dump_folder_path", parser.add_argument("--pytorch_dump_folder_path",
default = None, default = None,
type = str, type = str,

View File

@ -57,7 +57,7 @@ def build_tf_to_pytorch_map(model, config):
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible. This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
""" """
tf_to_pt_map = {} tf_to_pt_map = {}
# Embeddings cutoffs # Embeddings
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)): for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
layer_str = "transformer/adaptive_embed/cutoff_%d/" % i layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
tf_to_pt_map.update({ tf_to_pt_map.update({
@ -934,11 +934,11 @@ class TransfoXLPreTrainedModel(nn.Module):
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf: if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file) state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
if from_tf: if from_tf:
# Directly load from a TensorFlow checkpoint # Directly load from a TensorFlow checkpoint
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
return load_tf_weights_in_transfo_xl(model, weights_path)
missing_keys = [] missing_keys = []
unexpected_keys = [] unexpected_keys = []
error_msgs = [] error_msgs = []
@ -965,18 +965,49 @@ class TransfoXLPreTrainedModel(nn.Module):
if len(error_msgs) > 0: if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs))) model.__class__.__name__, "\n\t".join(error_msgs)))
# Make sure we are still sharing the input and output embeddings
if model.hasattr('tie_weights'):
model.tie_weights()
return model return model
class TransfoXLModel(TransfoXLPreTrainedModel): class TransfoXLModel(TransfoXLPreTrainedModel):
"""Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
- you don't need to specify positioning embeddings indices
- the tokens in the vocabulary have to be sorted to decreasing frequency.
Params:
config: a TransfoXLConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
Outputs:
A tuple of (last_hidden_state, new_mems)
`last_hidden_state`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [sequence_length, batch_size, self.config.d_model]
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])
config = TransfoXLConfig()
model = TransfoXLModel(config)
last_hidden_state, new_mems = model(input_ids)
# Another time on input_ids_next using the memory:
last_hidden_state, new_mems = model(input_ids_next, new_mems)
```
"""
def __init__(self, config): def __init__(self, config):
# n_token, n_layer, n_head, d_model, d_head, d_inner,
# dropout, dropatt, tie_weight=True, d_embed=None,
# div_val=1, tie_projs=[False], pre_lnorm=False,
# tgt_len=None, ext_len=None, mem_len=None,
# cutoffs=[], adapt_inp=False, untie_r=False,
# same_length=False, attn_type=0, clamp_len=-1,
# sample_softmax=-1, **kwargs):
super(TransfoXLModel, self).__init__(config) super(TransfoXLModel, self).__init__(config)
self.n_token = config.n_token self.n_token = config.n_token
@ -1034,31 +1065,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
r_r_bias=None if config.untie_r else self.r_r_bias) r_r_bias=None if config.untie_r else self.r_r_bias)
) )
self.sample_softmax = config.sample_softmax
# use sampled softmax
if config.sample_softmax > 0:
self.out_layer = nn.Linear(config.d_model, config.n_token)
if config.tie_weight:
self.out_layer.weight = self.word_emb.weight
self.tie_weight = config.tie_weight
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val)
if config.tie_weight:
for i in range(len(self.crit.out_layers)):
self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
if config.tie_projs:
for i, tie_proj in enumerate(config.tie_projs):
if tie_proj and config.div_val == 1 and config.d_model != config.d_embed:
self.crit.out_projs[i] = self.word_emb.emb_projs[0]
elif tie_proj and config.div_val != 1:
self.crit.out_projs[i] = self.word_emb.emb_projs[i]
self.same_length = config.same_length self.same_length = config.same_length
self.clamp_len = config.clamp_len self.clamp_len = config.clamp_len
@ -1074,6 +1080,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
elif self.attn_type == 3: # absolute deeper SA elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor( self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head)) self.n_layer, self.max_klen, self.n_head, self.d_head))
self.apply(self.init_weights)
def backward_compatible(self): def backward_compatible(self):
self.sample_softmax = -1 self.sample_softmax = -1
@ -1210,32 +1217,135 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return core_out, new_mems return core_out, new_mems
def forward(self, data, target=None, *mems): def forward(self, input_ids, mems=None):
# nn.DataParallel does not allow size(0) tensors to be broadcasted. """ Params:
# So, have to initialize size(0) mems inside the model forward. input_ids :: [len, bsz]
# Moreover, have to return new_mems to allow nn.DataParallel to piece Returns:
# them together. tuple (last_hidden, new_mems) where:
if not mems: new_mems: list (num layers) of mem states at the entry of each layer
mems = self.init_mems(data) shape :: [self.config.mem_len, bsz, self.config.d_model]
last_hidden: output of the last layer:
shape :: [len, bsz, self.config.d_model]
"""
if mems is None:
mems = self.init_mems(input_ids)
last_hidden, new_mems = self._forward(input_ids, mems=mems)
return (last_hidden, new_mems)
hidden, new_mems = self._forward(data, mems=mems)
if target is None: class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
if new_mems is None: """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
return [hidden]
This model add an (adaptive) softmax head on top of the TransfoXLModel
Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
- you don't need to specify positioning embeddings indices
- the tokens in the vocabulary have to be sorted to decreasing frequency.
Call self.tie_weights() if you update/load the weights of the transformer to keep the weights tied.
Params:
config: a TransfoXLConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
with the token indices selected in the range [0, self.config.n_token[
`target`: a torch.LongTensor of shape [sequence_length, batch_size]
with the target token indices selected in the range [0, self.config.n_token[
Outputs:
A tuple of (last_hidden_state, new_mems)
`softmax_output`: output of the (adaptive) softmax:
if target is None:
Negative log likelihood of shape :: [len, bsz]
else: else:
return [hidden] + new_mems log probabilities of tokens, shape :: [len, bsz, n_tokens]
`new_mems`: list (num layers) of updated mem states at the entry of each layer
each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
tgt_len = target.size(0) Example usage:
pred_hid = hidden[-tgt_len:] ```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])
config = TransfoXLConfig()
model = TransfoXLModel(config)
last_hidden_state, new_mems = model(input_ids)
# Another time on input_ids_next using the memory:
last_hidden_state, new_mems = model(input_ids_next, new_mems)
```
"""
def __init__(self, config):
super(TransfoXLLMHeadModel, self).__init__(config)
self.transformer = TransfoXLModel(config)
self.sample_softmax = config.sample_softmax
# use sampled softmax
if config.sample_softmax > 0:
self.out_layer = nn.Linear(config.d_model, config.n_token)
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val)
self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
""" Run this to be sure output and input (adaptive) softmax weights are tied """
# sampled softmax
if self.sample_softmax > 0:
if self.config.tie_weight:
self.out_layer.weight = self.transformer.word_emb.weight
# adaptive softmax (including standard softmax)
else:
if self.config.tie_weight:
for i in range(len(self.crit.out_layers)):
self.crit.out_layers[i].weight = self.transformer.word_emb.emb_layers[i].weight
if self.config.tie_projs:
for i, tie_proj in enumerate(self.config.tie_projs):
if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
elif tie_proj and self.config.div_val != 1:
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
def reset_length(self, tgt_len, ext_len, mem_len):
self.transformer.reset_length(tgt_len, ext_len, mem_len)
def init_mems(self, data):
return self.transformer.init_mems(data)
def forward(self, input_ids, target=None, mems=None):
""" Params:
input_ids :: [len, bsz]
target :: [len, bsz]
Returns:
tuple(softmax_output, new_mems) where:
new_mems: list (num layers) of hidden states at the entry of each layer
shape :: [mem_len, bsz, self.config.d_model]
softmax_output: output of the (adaptive) softmax:
if target is None:
Negative log likelihood of shape :: [len, bsz]
else:
log probabilities of tokens, shape :: [len, bsz, n_tokens]
"""
bsz = input_ids.size(1)
tgt_len = input_ids.size(0)
last_hidden, new_mems = self.transformer(input_ids, mems)
pred_hid = last_hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training: if self.sample_softmax > 0 and self.training:
assert self.tie_weight assert self.config.tie_weight
logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler) logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
loss = -F.log_softmax(logit, -1)[:, :, 0] loss = -F.log_softmax(logit, -1)[:, :, 0]
else: else:
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target)
loss = loss.view(tgt_len, -1) if target is None:
softmax_output = softmax_output.view(tgt_len, bsz, -1)
else:
softmax_output = softmax_output.view(tgt_len, bsz)
if new_mems is None: return (softmax_output, new_mems)
return [loss]
else:
return (loss, new_mems)