adding option to desactivate past/memory outputs

This commit is contained in:
thomwolf 2019-10-11 15:47:08 +02:00
parent 2a4fef837a
commit 0f9fc4fbde
8 changed files with 93 additions and 55 deletions

View File

@ -53,7 +53,8 @@ class PretrainedConfig(object):
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False)
self.output_past = kwargs.pop('output_past', True) # Not used by all models
self.torchscript = kwargs.pop('torchscript', False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})

View File

@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel):
def __init__(self, config):
super(CTRLModel, self).__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
self.output_attentions = config.output_attentions
self.w = nn.Embedding(config.vocab_size, config.n_embd)
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([EncoderLayer(config.n_embd,
config.n_head,
@ -378,7 +378,8 @@ class CTRLModel(CTRLPreTrainedModel):
attention_mask=attention_mask,
head_mask=head_mask[i])
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_past:
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:

View File

@ -347,6 +347,7 @@ class GPT2Model(GPT2PreTrainedModel):
super(GPT2Model, self).__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
@ -440,7 +441,8 @@ class GPT2Model(GPT2PreTrainedModel):
head_mask=head_mask[i])
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_past:
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
@ -452,7 +454,9 @@ class GPT2Model(GPT2PreTrainedModel):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
@ -460,7 +464,7 @@ class GPT2Model(GPT2PreTrainedModel):
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs # last hidden state, presents, (all hidden_states), (attentions)
return outputs # last hidden state, (presents), (all hidden_states), (attentions)
@add_start_docstrings("""The GPT2 Model transformer with a language modeling head on top

View File

@ -168,12 +168,14 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super(TFCTRLMainLayer, self).__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
self.output_attentions = config.output_attentions
self.w = TFSharedEmbeddings(config.vocab_size,
config.n_embd,
@ -290,7 +292,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_past:
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
@ -300,7 +304,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:

View File

@ -354,6 +354,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
super(TFXLNetMainLayer, self).__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
@ -413,16 +414,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len:]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len:]
return tf.stop_gradient(new_mem)
@ -538,8 +536,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " \
"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
if input_mask is None and attention_mask is not None:
input_mask = 1.0 - attention_mask
if input_mask is not None and perm_mask is not None:
@ -624,7 +622,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states = []
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@ -642,7 +641,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
output = self.dropout(output_g if output_g is not None else output_h, training=training)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (tf.transpose(output, perm=(1, 0, 2)), new_mems)
outputs = (tf.transpose(output, perm=(1, 0, 2)),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
outputs = outputs + (new_mems,)
if self.output_hidden_states:
if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
@ -653,7 +656,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
outputs = outputs + (attentions,)
return outputs # outputs, new_mems, (hidden_states), (attentions)
return outputs # outputs, (new_mems), (hidden_states), (attentions)
class TFXLNetPreTrainedModel(TFPreTrainedModel):
@ -768,7 +771,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -810,7 +813,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**prediction_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -854,7 +857,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, mems, (hidden states), (attentions)
return outputs # return logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
@ -865,7 +868,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**logits**: ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -909,7 +912,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # return logits, mems, (hidden states), (attentions)
return outputs # return logits, (mems), (hidden states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@ -923,6 +926,11 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
Span-start scores (before SoftMax).
**end_scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@ -962,7 +970,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel):
outputs = (start_logits, end_logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
return outputs # start_logits, end_logits, (hidden_states), (attentions)
return outputs # start_logits, end_logits, (mems), (hidden_states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,

View File

@ -555,7 +555,7 @@ class XLNetModel(XLNetPreTrainedModel):
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -581,6 +581,7 @@ class XLNetModel(XLNetPreTrainedModel):
super(XLNetModel, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
@ -637,16 +638,13 @@ class XLNetModel(XLNetPreTrainedModel):
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
return new_mem.detach()
@ -817,8 +815,9 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = []
hidden_states = []
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@ -836,7 +835,11 @@ class XLNetModel(XLNetPreTrainedModel):
output = self.dropout(output_g if output_g is not None else output_h)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
outputs = (output.permute(1, 0, 2).contiguous(),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
outputs = outputs + (new_mems,)
if self.output_hidden_states:
if output_g is not None:
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
@ -847,7 +850,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs = outputs + (attentions,)
return outputs # outputs, new_mems, (hidden_states), (attentions)
return outputs # outputs, (new_mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a language modeling head on top
@ -867,7 +870,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -932,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions)
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
@ -951,7 +954,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -1011,7 +1014,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions)
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
@ -1046,6 +1049,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
@ -1102,7 +1110,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
loss = loss_fct(reshaped_logits, labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, mems, (hidden states), (attentions)
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@ -1126,7 +1134,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
@ -1197,7 +1205,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
@ -1239,7 +1247,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
``torch.FloatTensor`` of shape ``(batch_size,)``
Log probabilities for the ``is_impossible`` label of the answers.
**mems**:
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.

View File

@ -161,6 +161,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"outputs": outputs.numpy(),
}
model.config.mem_len = 0
no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].shape),
[self.batch_size, self.seq_length, self.hidden_size])

View File

@ -150,6 +150,10 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"outputs": outputs,
}
model.config.mem_len = 0
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size])