transposing the inputs of Transformer-XL to have a unified interface

Repository: https://github.com/huggingface/transformers.git (mirror)
commit 884ca81d87, parent 32fea876bb

Diffstat: README.md | 16 (additional changed files appear in the hunks below)
--- a/README.md
+++ b/README.md
@@ -603,25 +603,25 @@ Transformer XL uses relative positioning with sinusoidal patterns and adaptive
 
 This model takes as *inputs*:
 [`modeling_transfo_xl.py`](./pytorch_pretrained_bert/modeling_transfo_xl.py)
-- `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size] with the token indices selected in the range [0, self.config.n_token[
-- `mems`: an optional memory of hidden states from previous forward passes, as a list (num layers) of hidden states at the entry of each layer. Each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
+- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the token indices selected in the range [0, self.config.n_token[
+- `mems`: an optional memory of hidden states from previous forward passes, as a list (num layers) of hidden states at the entry of each layer. Each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
 
 This model *outputs* a tuple of (last_hidden_state, new_mems)
-- `last_hidden_state`: the encoded hidden states at the top of the model, as a torch.FloatTensor of size [sequence_length, batch_size, self.config.d_model]
-- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
+- `last_hidden_state`: the encoded hidden states at the top of the model, as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model]
+- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
 
 #### 13. `TransfoXLLMHeadModel`
 
 `TransfoXLLMHeadModel` includes the `TransfoXLModel` Transformer followed by an (adaptive) softmax head with weights tied to the input embeddings.
 
 *Inputs* are the same as the inputs of the [`TransfoXLModel`](#-12.-`TransfoXLModel`) class, plus optional labels:
-- `target`: an optional torch.LongTensor of shape [sequence_length, batch_size] with the target token indices selected in the range [0, self.config.n_token[
+- `target`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the target token indices selected in the range [0, self.config.n_token[
 
 *Outputs* a tuple of (softmax_output, new_mems)
 - `softmax_output`: output of the (adaptive) softmax:
-  - if target is None: log probabilities of tokens, shape :: [len, bsz, n_tokens]
-  - else: negative log likelihood of shape :: [len, bsz]
-- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
+  - if target is None: log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
+  - else: negative log likelihood of shape [batch_size, sequence_length]
+- `new_mems`: list (num layers) of updated mem states at the entry of each layer; each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]. Note that the first two dimensions are transposed in `mems` with regards to `input_ids`.
 
 ### Tokenizers:
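For concreteness, here is a minimal usage sketch of the batch-first interface documented above. The `transfo-xl-wt103` shortcut name and the top-level import are assumptions consistent with the rest of this README, not part of the diff:

```python
import torch
from pytorch_pretrained_bert import TransfoXLModel  # top-level import assumed

model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# inputs are now batch-first: [batch_size, sequence_length]
input_ids = torch.randint(0, model.config.n_token, (2, 8), dtype=torch.long)
last_hidden_state, new_mems = model(input_ids)

print(last_hidden_state.shape)  # [2, 8, d_model] -- batch-first
print(new_mems[0].shape)        # [mem_len, 2, d_model] -- mems stay time-first

# feed the memory back in for the next segment
input_ids_next = torch.randint(0, model.config.n_token, (2, 8), dtype=torch.long)
last_hidden_state, new_mems = model(input_ids_next, mems=new_mems)
```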
--- a/pytorch_pretrained_bert/modeling_transfo_xl.py
+++ b/pytorch_pretrained_bert/modeling_transfo_xl.py
@@ -986,17 +986,19 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         config: a TransfoXLConfig class instance with the configuration to build a new model
 
     Inputs:
-        `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
             with the token indices selected in the range [0, self.config.n_token[
         `mems`: optional memory of hidden states from previous forward passes
             as a list (num layers) of hidden states at the entry of each layer
             each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids`
     Outputs:
         A tuple of (last_hidden_state, new_mems)
         `last_hidden_state`: the encoded hidden states at the top of the model
-            as a torch.FloatTensor of size [sequence_length, batch_size, self.config.d_model]
+            as a torch.FloatTensor of size [batch_size, sequence_length, self.config.d_model]
         `new_mems`: list (num layers) of updated mem states at the entry of each layer
             each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids`
 
     Example usage:
     ```python
@@ -1225,20 +1227,28 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
 
     def forward(self, input_ids, mems=None):
         """ Params:
-                input_ids :: [len, bsz]
+                input_ids :: [bsz, len]
                 mems :: optional mems from previous forward passes (or init_mems)
                     list (num layers) of mem states at the entry of each layer
                         shape :: [self.config.mem_len, bsz, self.config.d_model]
+                    Note that the first two dimensions are transposed in `mems` with regards to `input_ids`
             Returns:
                 tuple (last_hidden, new_mems) where:
                     new_mems: list (num layers) of mem states at the entry of each layer
                         shape :: [self.config.mem_len, bsz, self.config.d_model]
                     last_hidden: output of the last layer:
-                        shape :: [len, bsz, self.config.d_model]
+                        shape :: [bsz, len, self.config.d_model]
         """
+        # The original code for Transformer-XL used shapes [len, bsz], but we want a unified interface in the library
+        # so we transpose here from shape [bsz, len] to shape [len, bsz]
+        input_ids = input_ids.transpose(0, 1).contiguous()
+
         if mems is None:
             mems = self.init_mems(input_ids)
         last_hidden, new_mems = self._forward(input_ids, mems=mems)
 
+        # We transpose back here to shape [bsz, len, hidden_dim]
+        last_hidden = last_hidden.transpose(0, 1).contiguous()
         return (last_hidden, new_mems)
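This forward pass is the heart of the commit: callers see batch-first tensors while the internal Transformer-XL code stays time-first. A standalone sketch of the round-trip with plain tensors (sizes are illustrative only):

```python
import torch

bsz, seq_len, d_model, mem_len = 2, 8, 16, 32

input_ids = torch.zeros(bsz, seq_len, dtype=torch.long)  # caller side: [bsz, len]
# transpose() returns a non-contiguous view, hence the .contiguous()
internal_ids = input_ids.transpose(0, 1).contiguous()    # model side: [len, bsz]

hidden = torch.randn(seq_len, bsz, d_model)               # what _forward would produce
last_hidden = hidden.transpose(0, 1).contiguous()         # returned as [bsz, len, d_model]

# mems are returned without the transpose back, so they keep the internal layout:
mem = torch.randn(mem_len, bsz, d_model)
assert last_hidden.shape == (bsz, seq_len, d_model)
assert mem.shape == (mem_len, bsz, d_model)
```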
@@ -1257,23 +1267,25 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         config: a TransfoXLConfig class instance with the configuration to build a new model
 
     Inputs:
-        `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size]
+        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
             with the token indices selected in the range [0, self.config.n_token[
-        `target`: an optional torch.LongTensor of shape [sequence_length, batch_size]
+        `target`: an optional torch.LongTensor of shape [batch_size, sequence_length]
             with the target token indices selected in the range [0, self.config.n_token[
         `mems`: an optional memory of hidden states from previous forward passes
             as a list (num layers) of hidden states at the entry of each layer
             each hidden state has shape [self.config.mem_len, bsz, self.config.d_model]
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
 
     Outputs:
         A tuple of (softmax_output, new_mems)
         `softmax_output`: output of the (adaptive) softmax:
             if target is None:
-                log probabilities of tokens, shape :: [len, bsz, n_tokens]
+                log probabilities of tokens, shape [batch_size, sequence_length, n_tokens]
             else:
-                negative log likelihood of shape :: [len, bsz]
+                negative log likelihood of shape [batch_size, sequence_length]
         `new_mems`: list (num layers) of updated mem states at the entry of each layer
             each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
 
     Example usage:
     ```python
@@ -1287,7 +1299,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     last_hidden_state, new_mems = model(input_ids)
 
     # Another time on input_ids_next using the memory:
-    last_hidden_state, new_mems = model(input_ids_next, new_mems)
+    last_hidden_state, new_mems = model(input_ids_next, mems=new_mems)
     ```
     """
     def __init__(self, config):
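Extending the docstring example into a runnable sketch covering both output modes (shortcut name and import are assumed as above; random ids stand in for real tokenized text):

```python
import torch
from pytorch_pretrained_bert import TransfoXLLMHeadModel  # top-level import assumed

model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
model.eval()

input_ids = torch.randint(0, model.config.n_token, (1, 8), dtype=torch.long)

# target is None: per-token log probabilities, batch-first
log_probs, mems = model(input_ids)  # log_probs: [1, 8, n_token]

# with target: per-token negative log likelihood, batch-first
input_ids_next = torch.randint(0, model.config.n_token, (1, 8), dtype=torch.long)
target = torch.randint(0, model.config.n_token, (1, 8), dtype=torch.long)
nll, mems = model(input_ids_next, target=target, mems=mems)  # nll: [1, 8]
loss = nll.mean()  # reduce to a scalar for backprop
```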
@@ -1331,33 +1343,34 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
 
     def forward(self, input_ids, target=None, mems=None):
         """ Params:
-                input_ids :: [len, bsz]
-                target :: [len, bsz]
+                input_ids :: [bsz, len]
+                target :: [bsz, len]
             Returns:
                 tuple(softmax_output, new_mems) where:
                     new_mems: list (num layers) of hidden states at the entry of each layer
-                        shape :: [mem_len, bsz, self.config.d_model]
+                        shape :: [mem_len, bsz, self.config.d_model] :: Warning: shapes are transposed here w. regards to input_ids
                     softmax_output: output of the (adaptive) softmax:
                         if target is None:
-                            log probabilities of tokens, shape :: [len, bsz, n_tokens]
+                            log probabilities of tokens, shape :: [bsz, len, n_tokens]
                         else:
-                            negative log likelihood of shape :: [len, bsz]
+                            negative log likelihood of shape :: [bsz, len]
         """
-        bsz = input_ids.size(1)
-        tgt_len = input_ids.size(0)
+        bsz = input_ids.size(0)
+        tgt_len = input_ids.size(1)
 
         last_hidden, new_mems = self.transformer(input_ids, mems)
 
-        pred_hid = last_hidden[-tgt_len:]
+        pred_hid = last_hidden[:, -tgt_len:]
         if self.sample_softmax > 0 and self.training:
             assert self.config.tie_weight
             logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
-            loss = -F.log_softmax(logit, -1)[:, :, 0]
+            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
         else:
             softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target)
             if target is None:
-                softmax_output = softmax_output.view(tgt_len, bsz, -1)
+                softmax_output = softmax_output.view(bsz, tgt_len, -1)
             else:
-                softmax_output = softmax_output.view(tgt_len, bsz)
 
+                softmax_output = softmax_output.view(bsz, tgt_len)
+
+        # softmax_output is already batch-first here; new_mems keep the time-first layout
         return (softmax_output, new_mems)
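One subtle change above: `pred_hid = last_hidden[-tgt_len:]` became `last_hidden[:, -tgt_len:]`, because the time dimension moved from dim 0 to dim 1. A toy illustration that the two slices select the same values (sizes are arbitrary):

```python
import torch

bsz, total_len, tgt_len, d_model = 2, 12, 8, 16
time_first = torch.randn(total_len, bsz, d_model)

old_pred_hid = time_first[-tgt_len:]                     # old layout: slice dim 0
new_pred_hid = time_first.transpose(0, 1)[:, -tgt_len:]  # new layout: slice dim 1

assert torch.equal(new_pred_hid, old_pred_hid.transpose(0, 1))
```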
--- a/pytorch_pretrained_bert/tokenization_transfo_xl.py
+++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -507,7 +507,7 @@ class TransfoXLCorpus(object):
             resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
         except EnvironmentError:
             logger.error(
-                "Model name '{}' was not found in model name list ({}). "
+                "Corpus '{}' was not found in corpus list ({}). "
                 "We assumed '{}' was a path or url but couldn't find files {} "
                 "at this path or url.".format(
                     pretrained_model_name_or_path,
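For context, this message is emitted from the pretrained-loading path of `TransfoXLCorpus` when neither a known shortcut name nor a resolvable file is given. A hedged usage sketch (the import path and shortcut name are assumptions consistent with the rest of the codebase):

```python
from pytorch_pretrained_bert import TransfoXLCorpus  # import path assumed

# Resolves and caches a pre-processed corpus by shortcut name or path;
# on failure the error above is logged and loading aborts.
corpus = TransfoXLCorpus.from_pretrained('transfo-xl-wt103')
```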
--- a/tests/modeling_transfo_xl_test.py
+++ b/tests/modeling_transfo_xl_test.py
@@ -67,12 +67,12 @@ class TransfoXLModelTest(unittest.TestCase):
             self.seed = seed
 
         def prepare_config_and_inputs(self):
-            input_ids_1 = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
-            input_ids_2 = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
+            input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
             lm_labels = None
             if self.use_labels:
-                lm_labels = TransfoXLModelTest.ids_tensor([self.seq_length, self.batch_size], self.vocab_size)
+                lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
 
             config = TransfoXLConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
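`ids_tensor` is the test suite's helper for drawing random token ids of a given shape; a minimal equivalent, to make the shape change concrete (a sketch, not the actual helper):

```python
import torch

def ids_tensor(shape, vocab_size):
    # random torch.LongTensor of token ids in [0, vocab_size)
    return torch.randint(low=0, high=vocab_size, size=tuple(shape), dtype=torch.long)

input_ids_1 = ids_tensor([13, 7], vocab_size=99)  # now [batch_size, seq_length]
assert input_ids_1.shape == (13, 7)
```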
@@ -110,13 +110,13 @@ class TransfoXLModelTest(unittest.TestCase):
         def check_transfo_xl_model_output(self, result):
             self.parent.assertListEqual(
                 list(result["hidden_states_1"].size()),
-                [self.seq_length, self.batch_size, self.d_model])
+                [self.batch_size, self.seq_length, self.d_model])
+            self.parent.assertListEqual(
+                list(result["hidden_states_2"].size()),
+                [self.batch_size, self.seq_length, self.d_model])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(result["hidden_states_2"].size()),
-                [self.seq_length, self.batch_size, self.d_model])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
@@ -147,13 +147,13 @@ class TransfoXLModelTest(unittest.TestCase):
         def check_transfo_xl_lm_head_output(self, result):
             self.parent.assertListEqual(
                 list(result["loss_1"].size()),
-                [self.seq_length, self.batch_size])
+                [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(
+                list(result["lm_logits_1"].size()),
+                [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1a"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(result["lm_logits_1"].size()),
-                [self.seq_length, self.batch_size, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1b"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
@@ -163,13 +163,13 @@ class TransfoXLModelTest(unittest.TestCase):
 
             self.parent.assertListEqual(
                 list(result["loss_2"].size()),
-                [self.seq_length, self.batch_size])
+                [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(
+                list(result["lm_logits_2"].size()),
+                [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2a"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(result["lm_logits_2"].size()),
-                [self.seq_length, self.batch_size, self.vocab_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2b"]),
                 [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
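Taken together, the updated assertions pin down the new contract of this commit: model outputs are batch-first, while memories keep the time-first layout. The same invariants restated with plain tensors (illustrative sizes, no model run):

```python
import torch

batch_size, seq_length, d_model, vocab_size, mem_len, n_layer = 13, 7, 32, 99, 30, 5

hidden_states = torch.randn(batch_size, seq_length, d_model)  # TransfoXLModel output
lm_logits = torch.randn(batch_size, seq_length, vocab_size)   # LM head log probs
loss = torch.randn(batch_size, seq_length)                    # per-token NLL
mems = [torch.randn(mem_len, batch_size, d_model) for _ in range(n_layer)]

assert list(hidden_states.size()) == [batch_size, seq_length, d_model]
assert list(lm_logits.size()) == [batch_size, seq_length, vocab_size]
assert list(loss.size()) == [batch_size, seq_length]
assert [list(m.size()) for m in mems] == [[mem_len, batch_size, d_model]] * n_layer
```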