Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)

Commit c306869ea2: add two transformer xl models
Parent: d482e3d79d
@@ -11,7 +11,7 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
 from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
                               OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
                               load_tf_weights_in_openai_gpt)
-from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel,
+from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
                                   load_tf_weights_in_transfo_xl)

 from .optimization import BertAdam
@@ -27,7 +27,7 @@ import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
 from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
                                                           WEIGHTS_NAME,
                                                           TransfoXLConfig,
-                                                          TransfoXLModel,
+                                                          TransfoXLLMHeadModel,
                                                           load_tf_weights_in_transfo_xl)
 from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
                                                               VOCAB_NAME)
@@ -37,7 +37,7 @@ if sys.version_info[0] == 2:
 else:
     import pickle

-# We do this to be able to load the python 2 datasets pickles
+# We do this to be able to load python 2 datasets pickles
 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
 data_utils.Vocab = data_utils.TransfoXLTokenizer
 data_utils.Corpus = data_utils.TransfoXLCorpus
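The aliasing above exists so that `pickle` can resolve class names baked into the original Python 2 corpus files. A self-contained sketch of the same trick, using stand-in module and class names rather than the repository's own:

```python
import pickle
import sys
import types

# Stand-in for the original code base: a module "data_utils" whose Vocab
# class ends up referenced inside pickled corpus files.
legacy = types.ModuleType("data_utils")

class Vocab(object):
    def __init__(self, tokens):
        self.tokens = tokens

Vocab.__module__ = "data_utils"
legacy.Vocab = Vocab
sys.modules["data_utils"] = legacy

blob = pickle.dumps(Vocab(["hello", "world"]))  # stream references data_utils.Vocab

# Later the class has been renamed; as long as the old dotted name still
# resolves to a compatible class, pickle.loads() succeeds.
class TransfoXLTokenizer(object):
    def __init__(self, tokens):
        self.tokens = tokens

sys.modules["data_utils"].Vocab = TransfoXLTokenizer  # the aliasing trick
restored = pickle.loads(blob)
print(type(restored).__name__, restored.tokens)  # TransfoXLTokenizer ['hello', 'world']
```

The `encoding="latin1"` argument seen in the conversion script below is the companion piece: it lets Python 3 read byte strings written by Python 2 pickles.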
@@ -49,6 +49,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                              pytorch_dump_folder_path,
                                              transfo_xl_dataset_file):
     if transfo_xl_dataset_file:
+        # Convert a pre-processed corpus (see original TensorFlow repo)
         with open(transfo_xl_dataset_file, "rb") as fp:
             corpus = pickle.load(fp, encoding="latin1")
         # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
@@ -64,18 +65,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
         torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

     if tf_checkpoint_path:
+        # Convert a pre-trained TensorFlow model
         config_path = os.path.abspath(transfo_xl_config_file)
         tf_path = os.path.abspath(tf_checkpoint_path)

         print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
         # Initialise PyTorch model
-        # Construct model
         if transfo_xl_config_file == "":
             config = TransfoXLConfig()
         else:
             config = TransfoXLConfig(transfo_xl_config_file)
         print("Building PyTorch model from configuration: {}".format(str(config)))
-        model = TransfoXLModel(config)
+        model = TransfoXLLMHeadModel(config)

         model = load_tf_weights_in_transfo_xl(model, config, tf_path)
         # Save pytorch-model
@@ -90,7 +91,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    ## Required parameters
     parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
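For reference, a hedged sketch of driving the conversion entry point directly from Python rather than via the CLI. The import path and keyword names follow the parameters visible in the hunks above; they are inferred, and all paths are placeholders:

```python
# Hypothetical invocation of the conversion function shown above. Keyword
# arguments are used so the sketch does not depend on the exact positional
# order of the parameters.
from pytorch_pretrained_bert.convert_transfo_xl_checkpoint_to_pytorch import (
    convert_transfo_xl_checkpoint_to_pytorch,
)

convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/transfo_xl_tf_checkpoint",  # "" skips the TF model conversion
    transfo_xl_config_file="",                               # "" falls back to a default TransfoXLConfig()
    pytorch_dump_folder_path="/path/to/output_dir",          # where the PyTorch weights/config are written
    transfo_xl_dataset_file="",                              # "" skips the corpus conversion
)
```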
@@ -57,7 +57,7 @@ def build_tf_to_pytorch_map(model, config):
     This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
     """
     tf_to_pt_map = {}
-    # Embeddings cutoffs
+    # Embeddings
     for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
         layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
         tf_to_pt_map.update({
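The map built here pairs TensorFlow variable names such as `transformer/adaptive_embed/cutoff_%d/...` with the corresponding PyTorch parameters; loading then amounts to walking the checkpoint and copying each array into the mapped tensor. A generic sketch of that copy loop, with made-up names and a plain dict standing in for the real checkpoint reader:

```python
import numpy as np
import torch
import torch.nn as nn

# Toy model and a fake "checkpoint" dict; the names below are illustrative,
# not the actual Transformer-XL variable names.
model = nn.Linear(4, 3)
tf_weights = {
    "transformer/dense/kernel": np.random.randn(4, 3).astype(np.float32),
    "transformer/dense/bias": np.random.randn(3).astype(np.float32),
}

# Name map in the spirit of build_tf_to_pytorch_map: TF variable name -> PyTorch parameter.
tf_to_pt_map = {
    "transformer/dense/kernel": model.weight,
    "transformer/dense/bias": model.bias,
}

for name, pointer in tf_to_pt_map.items():
    array = tf_weights[name]
    if name.endswith("kernel"):
        array = array.T  # TF stores kernels as [in, out]; PyTorch expects [out, in]
    assert pointer.shape == array.shape
    with torch.no_grad():
        pointer.copy_(torch.from_numpy(array))
```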
@@ -934,11 +934,11 @@ class TransfoXLPreTrainedModel(nn.Module):
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-            state_dict = torch.load(resolved_archive_file)
+            state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None)
         if from_tf:
             # Directly load from a TensorFlow checkpoint
-            weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
-            return load_tf_weights_in_transfo_xl(model, weights_path)
+            return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path)
+
         missing_keys = []
         unexpected_keys = []
         error_msgs = []
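The `map_location` change above matters when a checkpoint that was saved on GPU is loaded on a CPU-only machine: without it, `torch.load` tries to restore CUDA tensors and fails. A small self-contained illustration of the same guard outside the library code:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)                      # stand-in for the real model class
torch.save(model.state_dict(), "pytorch_model.bin")

# Mirrors the change in the hunk above: fall back to CPU deserialization
# when no GPU is available, otherwise let torch.load keep its default.
state_dict = torch.load(
    "pytorch_model.bin",
    map_location="cpu" if not torch.cuda.is_available() else None,
)
model.load_state_dict(state_dict)
```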
@@ -965,18 +965,49 @@ class TransfoXLPreTrainedModel(nn.Module):
         if len(error_msgs) > 0:
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                model.__class__.__name__, "\n\t".join(error_msgs)))
+        # Make sure we are still sharing the input and output embeddings
+        if model.hasattr('tie_weights'):
+            model.tie_weights()
         return model


 class TransfoXLModel(TransfoXLPreTrainedModel):
+    """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
+
+    Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
+    - you don't need to specify positioning embeddings indices
+    - the tokens in the vocabulary have to be sorted to decreasing frequency.
+
+    Params:
+        config: a TransfoXLConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
+            with the token indices selected in the range [0, self.config.n_token[
+
+    Outputs:
+        A tuple of (last_hidden_state, new_mems)
+        `last_hidden_state`: the encoded-hidden-states at the top of the model
+            as a torch.FloatTensor of size [sequence_length, batch_size, self.config.d_model]
+        `new_mems`: list (num layers) of updated mem states at the entry of each layer
+            each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])
+
+    config = TransfoXLConfig()
+
+    model = TransfoXLModel(config)
+    last_hidden_state, new_mems = model(input_ids)
+
+    # Another time on input_ids_next using the memory:
+    last_hidden_state, new_mems = model(input_ids_next, new_mems)
+    ```
+    """
     def __init__(self, config):
-        # n_token, n_layer, n_head, d_model, d_head, d_inner,
-        # dropout, dropatt, tie_weight=True, d_embed=None,
-        # div_val=1, tie_projs=[False], pre_lnorm=False,
-        # tgt_len=None, ext_len=None, mem_len=None,
-        # cutoffs=[], adapt_inp=False, untie_r=False,
-        # same_length=False, attn_type=0, clamp_len=-1,
-        # sample_softmax=-1, **kwargs):
         super(TransfoXLModel, self).__init__(config)
         self.n_token = config.n_token

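The post-load hook added in this hunk re-ties the input and output embeddings after the state dict has been copied in. Note that Python objects do not normally expose a `hasattr` method; the built-in form of that check is `hasattr(model, 'tie_weights')`. A minimal sketch of the intended hook, using a toy model in place of the library classes:

```python
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy stand-in for a model that ties its input and output embeddings."""
    def __init__(self, vocab_size=100, d_model=16):
        super(TinyLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.out = nn.Linear(d_model, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        self.out.weight = self.embed.weight  # share one Parameter object

model = TinyLM()
model.load_state_dict(model.state_dict())  # stand-in for the from_pretrained load step

# Built-in hasattr form of the guard used in the hunk above.
if hasattr(model, "tie_weights"):
    model.tie_weights()                     # re-tie after loading new weights
assert model.out.weight is model.embed.weight
```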
@@ -1034,31 +1065,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                     r_r_bias=None if config.untie_r else self.r_r_bias)
             )

-        self.sample_softmax = config.sample_softmax
-        # use sampled softmax
-        if config.sample_softmax > 0:
-            self.out_layer = nn.Linear(config.d_model, config.n_token)
-            if config.tie_weight:
-                self.out_layer.weight = self.word_emb.weight
-            self.tie_weight = config.tie_weight
-            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
-
-        # use adaptive softmax (including standard softmax)
-        else:
-            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
-                                                    config.cutoffs, div_val=config.div_val)
-
-            if config.tie_weight:
-                for i in range(len(self.crit.out_layers)):
-                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
-
-            if config.tie_projs:
-                for i, tie_proj in enumerate(config.tie_projs):
-                    if tie_proj and config.div_val == 1 and config.d_model != config.d_embed:
-                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
-                    elif tie_proj and config.div_val != 1:
-                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]
-
         self.same_length = config.same_length
         self.clamp_len = config.clamp_len

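The sampled/adaptive softmax construction removed from `TransfoXLModel` here is not dropped: it reappears further down inside the new `TransfoXLLMHeadModel`, which wraps the backbone and owns the output head. A generic sketch of that backbone-plus-head split, with toy modules standing in for the library's classes:

```python
import torch
import torch.nn as nn

class Backbone(nn.Module):
    """Returns hidden states only, like the slimmed-down TransfoXLModel."""
    def __init__(self, vocab_size=100, d_model=16):
        super(Backbone, self).__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.encoder = nn.GRU(d_model, d_model)   # toy stand-in for the attention stack

    def forward(self, input_ids):
        hidden, _ = self.encoder(self.embed(input_ids))
        return hidden

class WithLMHead(nn.Module):
    """Adds the output projection on top, like the new TransfoXLLMHeadModel."""
    def __init__(self, vocab_size=100, d_model=16):
        super(WithLMHead, self).__init__()
        self.transformer = Backbone(vocab_size, d_model)
        self.out_layer = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        return self.out_layer(self.transformer(input_ids))

input_ids = torch.randint(0, 100, (5, 2))   # [len, bsz], matching the diff's convention
logits = WithLMHead()(input_ids)             # [len, bsz, vocab_size]
```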
@@ -1074,6 +1080,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         elif self.attn_type == 3: # absolute deeper SA
             self.r_emb = nn.Parameter(torch.Tensor(
                     self.n_layer, self.max_klen, self.n_head, self.d_head))
+        self.apply(self.init_weights)

     def backward_compatible(self):
         self.sample_softmax = -1
@@ -1210,32 +1217,135 @@ class TransfoXLModel(TransfoXLPreTrainedModel):

         return core_out, new_mems

-    def forward(self, data, target=None, *mems):
-        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
-        # So, have to initialize size(0) mems inside the model forward.
-        # Moreover, have to return new_mems to allow nn.DataParallel to piece
-        # them together.
-        if not mems:
-            mems = self.init_mems(data)
-
-        hidden, new_mems = self._forward(data, mems=mems)
-        if target is None:
-            if new_mems is None:
-                return [hidden]
-            else:
-                return [hidden] + new_mems
-
-        tgt_len = target.size(0)
-        pred_hid = hidden[-tgt_len:]
+    def forward(self, input_ids, mems=None):
+        """ Params:
+                input_ids :: [len, bsz]
+            Returns:
+                tuple (last_hidden, new_mems) where:
+                    new_mems: list (num layers) of mem states at the entry of each layer
+                        shape :: [self.config.mem_len, bsz, self.config.d_model]
+                    last_hidden: output of the last layer:
+                        shape :: [len, bsz, self.config.d_model]
+        """
+        if mems is None:
+            mems = self.init_mems(input_ids)
+        last_hidden, new_mems = self._forward(input_ids, mems=mems)
+        return (last_hidden, new_mems)
+
+
+class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
+    """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").
+
+    This model add an (adaptive) softmax head on top of the TransfoXLModel
+
+    Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that:
+    - you don't need to specify positioning embeddings indices
+    - the tokens in the vocabulary have to be sorted to decreasing frequency.
+
+    Call self.tie_weights() if you update/load the weights of the transformer to keep the weights tied.
+
+    Params:
+        config: a TransfoXLConfig class instance with the configuration to build a new model
+
+    Inputs:
+        `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
+            with the token indices selected in the range [0, self.config.n_token[
+        `target`: a torch.LongTensor of shape [sequence_length, batch_size]
+            with the target token indices selected in the range [0, self.config.n_token[
+
+    Outputs:
+        A tuple of (last_hidden_state, new_mems)
+        `softmax_output`: output of the (adaptive) softmax:
+            if target is None:
+                Negative log likelihood of shape :: [len, bsz]
+            else:
+                log probabilities of tokens, shape :: [len, bsz, n_tokens]
+        `new_mems`: list (num layers) of updated mem states at the entry of each layer
+            each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
+
+    Example usage:
+    ```python
+    # Already been converted into BPE token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])
+
+    config = TransfoXLConfig()
+
+    model = TransfoXLModel(config)
+    last_hidden_state, new_mems = model(input_ids)
+
+    # Another time on input_ids_next using the memory:
+    last_hidden_state, new_mems = model(input_ids_next, new_mems)
+    ```
+    """
+    def __init__(self, config):
+        super(TransfoXLLMHeadModel, self).__init__(config)
+        self.transformer = TransfoXLModel(config)
+        self.sample_softmax = config.sample_softmax
+        # use sampled softmax
+        if config.sample_softmax > 0:
+            self.out_layer = nn.Linear(config.d_model, config.n_token)
+            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
+        # use adaptive softmax (including standard softmax)
+        else:
+            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
+                                                    config.cutoffs, div_val=config.div_val)
+        self.apply(self.init_weights)
+        self.tie_weights()
+
+    def tie_weights(self):
+        """ Run this to be sure output and input (adaptive) softmax weights are tied """
+        # sampled softmax
+        if self.sample_softmax > 0:
+            if self.config.tie_weight:
+                self.out_layer.weight = self.transformer.word_emb.weight
+        # adaptive softmax (including standard softmax)
+        else:
+            if self.config.tie_weight:
+                for i in range(len(self.crit.out_layers)):
+                    self.crit.out_layers[i].weight = self.transformer.word_emb.emb_layers[i].weight
+            if self.config.tie_projs:
+                for i, tie_proj in enumerate(self.config.tie_projs):
+                    if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
+                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
+                    elif tie_proj and self.config.div_val != 1:
+                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
+
+    def reset_length(self, tgt_len, ext_len, mem_len):
+        self.transformer.reset_length(tgt_len, ext_len, mem_len)
+
+    def init_mems(self, data):
+        return self.transformer.init_mems(data)
+
+    def forward(self, input_ids, target=None, mems=None):
+        """ Params:
+                input_ids :: [len, bsz]
+                target :: [len, bsz]
+            Returns:
+                tuple(softmax_output, new_mems) where:
+                    new_mems: list (num layers) of hidden states at the entry of each layer
+                        shape :: [mem_len, bsz, self.config.d_model]
+                    softmax_output: output of the (adaptive) softmax:
+                        if target is None:
+                            Negative log likelihood of shape :: [len, bsz]
+                        else:
+                            log probabilities of tokens, shape :: [len, bsz, n_tokens]
+        """
+        bsz = input_ids.size(1)
+        tgt_len = input_ids.size(0)
+
+        last_hidden, new_mems = self.transformer(input_ids, mems)
+
+        pred_hid = last_hidden[-tgt_len:]
         if self.sample_softmax > 0 and self.training:
-            assert self.tie_weight
-            logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
+            assert self.config.tie_weight
+            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
             loss = -F.log_softmax(logit, -1)[:, :, 0]
         else:
-            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
-            loss = loss.view(tgt_len, -1)
-
-        if new_mems is None:
-            return [loss]
-        else:
-            return (loss, new_mems)
+            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target)
+            if target is None:
+                softmax_output = softmax_output.view(tgt_len, bsz, -1)
+            else:
+                softmax_output = softmax_output.view(tgt_len, bsz)
+
+        return (softmax_output, new_mems)
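The forward contract of the new LM head returns either per-token negative log likelihood (when `target` is given) or full log probabilities over the vocabulary. A self-contained illustration of how those two output views relate, using plain tensors rather than the library model:

```python
import torch
import torch.nn.functional as F

seq_len, bsz, n_token = 3, 2, 10
logits = torch.randn(seq_len, bsz, n_token)          # stand-in for the head's raw scores
target = torch.randint(0, n_token, (seq_len, bsz))   # [len, bsz], as in the docstring

# Without a target: log probabilities of every token, shape [len, bsz, n_token]
log_probs = F.log_softmax(logits, dim=-1)

# With a target: negative log likelihood of the target tokens, shape [len, bsz]
nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)

assert log_probs.shape == (seq_len, bsz, n_token)
assert nll.shape == (seq_len, bsz)
```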