[T5, generation] Add decoder caching for T5 (#3682)

* initial commit to add decoder caching for T5

* better naming for caching

* finish T5 decoder caching

* correct test

* added extensive past testing for T5

* clean files

* make tests cleaner

* improve docstring

* improve docstring

* better reorder cache

* make style

* Update src/transformers/modeling_t5.py

Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>

* make set output past work for all layers

* improve docstring

* improve docstring

Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>
This commit is contained in:
Patrick von Platen 2020-04-10 01:02:50 +02:00 committed by GitHub
parent 9384e5f6de
commit ce2298fb5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 386 additions and 82 deletions

View File

@ -16,7 +16,6 @@
import copy
import itertools
import logging
import math
import os
@ -185,13 +184,11 @@ class T5LayerFF(nn.Module):
class T5Attention(nn.Module):
NEW_ID = itertools.count()
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.layer_id = next(T5Attention.NEW_ID)
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.output_past = config.output_past
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets
@ -294,15 +291,37 @@ class T5Attention(nn.Module):
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
return values
def forward(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None):
def forward(
self,
input,
mask=None,
kv=None,
position_bias=None,
past_key_value_state=None,
head_mask=None,
query_length=None,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = input.size()
if past_key_value_state is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value_state) == 2
), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value_state)
)
real_qlen = qlen + past_key_value_state[0].shape[2] if query_length is None else query_length
else:
real_qlen = qlen
if kv is None:
klen = qlen if cache is None else cache["slen"] + qlen
klen = real_qlen
else:
klen = kv.size(1)
@ -315,23 +334,27 @@ class T5Attention(nn.Module):
return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)
q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif cache is None or self.layer_id not in cache:
elif past_key_value_state is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if cache is not None:
if self.layer_id in cache:
if kv is None:
k_, v_ = cache[self.layer_id]
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
if past_key_value_state is not None:
if kv is None:
k_, v_ = past_key_value_state
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value_state
if self.is_decoder and self.output_past:
present_key_value_state = ((k, v),)
else:
present_key_value_state = (None,)
# q = q / math.sqrt(dim_per_head) # No scaling in T5
scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen)
@ -339,7 +362,13 @@ class T5Attention(nn.Module):
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(qlen, klen)
position_bias = self.compute_bias(real_qlen, klen)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value_state is not None:
position_bias = position_bias[:, :, -1:, :]
if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
@ -357,6 +386,13 @@ class T5Attention(nn.Module):
context = self.o(context)
outputs = (context,)
if self.output_past is False or self.is_decoder is False:
assert (
present_key_value_state[0] is None
), "Key/Value projections should not be stored if {} is not decoder or output_past is False".format(self)
outputs = outputs + present_key_value_state
if self.output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
@ -371,10 +407,16 @@ class T5LayerSelfAttention(nn.Module):
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
def forward(
self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, past_key_value_state=None
):
norm_x = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask
norm_x,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
@ -389,10 +431,25 @@ class T5LayerCrossAttention(nn.Module):
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
def forward(
self,
hidden_states,
kv,
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
query_length=None,
):
norm_x = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask
norm_x,
mask=attention_mask,
kv=kv,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
query_length=query_length,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
@ -403,14 +460,14 @@ class T5LayerCrossAttention(nn.Module):
class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.output_past = config.output_past
self.is_decoder = config.is_decoder
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
self.layer.append(T5LayerFF(config))
else:
self.layer.append(T5LayerFF(config))
self.layer.append(T5LayerFF(config))
def forward(
self,
@ -421,31 +478,63 @@ class T5Block(nn.Module):
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
head_mask=None,
past_key_value_state=None,
):
self_attention_outputs = self.layer[0](
hidden_states, attention_mask=attention_mask, position_bias=position_bias, head_mask=head_mask
)
hidden_states = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
if not self.is_decoder:
hidden_states = self.layer[1](hidden_states)
if past_key_value_state is not None:
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
assert (
len(past_key_value_state) == 4
), "The should be 4 past states. 2 (past / key) for self attention. 2 (past / key) for cross attention. Got {} past key / value states".format(
len(past_key_value_state)
)
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
else:
self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
if self.is_decoder:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1](
hidden_states,
kv=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
past_key_value_state=cross_attn_past_key_value_state,
query_length=query_length,
)
hidden_states = cross_attention_outputs[0]
outputs = (
outputs + cross_attention_outputs[1:]
) # Keep cross-attention outputs and relative position weights
hidden_states = self.layer[2](hidden_states)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
outputs = (hidden_states,) + outputs # add attentions if we output them
return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
outputs = (hidden_states,)
# Add attentions if we output them
outputs = outputs + (present_key_value_state,) + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
class T5PreTrainedModel(PreTrainedModel):
@ -531,6 +620,7 @@ class T5Stack(T5PreTrainedModel):
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.output_past = config.output_past and self.is_decoder
self.block = nn.ModuleList(
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
@ -557,6 +647,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
):
if input_ids is not None and inputs_embeds is not None:
@ -575,25 +666,41 @@ class T5Stack(T5PreTrainedModel):
batch_size, seq_length = input_shape
if past_key_value_states is not None:
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
input_shape, (batch_size, 1)
)
# required mask seq length can be calculated via length of past
# key value states and seq_length = 1 for the last token
mask_seq_length = past_key_value_states[0][0].shape[2] + seq_length
else:
mask_seq_length = seq_length
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length).to(inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
# initialize past_key_value_states with `None` if past does not exist
if past_key_value_states is None:
past_key_value_states = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if self.config.is_decoder:
seq_ids = torch.arange(seq_length, device=inputs_embeds.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
seq_ids = torch.arange(mask_seq_length, device=inputs_embeds.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(attention_mask)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if self.output_past and past_key_value_states[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
@ -610,9 +717,9 @@ class T5Stack(T5PreTrainedModel):
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder:
if self.is_decoder and encoder_attention_mask is not None:
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
# we need to make broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
@ -633,7 +740,7 @@ class T5Stack(T5PreTrainedModel):
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x mask_seq_length x mask_seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
@ -648,13 +755,15 @@ class T5Stack(T5PreTrainedModel):
else:
head_mask = [None] * self.config.num_layers
present_key_value_states = ()
all_hidden_states = ()
all_attentions = ()
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs_embeds)
for i, layer_module in enumerate(self.block):
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -666,19 +775,22 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
past_key_value_state=past_key_value_state,
)
# layer_outputs is a tuple with:
# hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states = layer_outputs[0]
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2 if self.output_attentions else 1]
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2]
if self.is_decoder:
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],) # We keep only self-attention weights for now
all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
@ -688,11 +800,13 @@ class T5Stack(T5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.is_decoder and self.output_past:
outputs = outputs + (present_key_value_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
return outputs # last-layer hidden state, (presents,) (all hidden states), (all attentions)
T5_START_DOCSTRING = r""" The T5 model was proposed in
@ -719,7 +833,7 @@ T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
T5 is a model with relative position embeddings so you should be able to pad the inputs on
T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. If `decoder_past_key_value_states` is used, optionally only the last `input_ids` have to be input (see `decoder_past_key_value_states`).
Indices can be obtained using :class:`transformers.T5Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
@ -739,6 +853,9 @@ T5_INPUTS_DOCSTRING = r"""
`T5 Training <./t5.html#training>`_ .
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding. If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
@ -780,6 +897,20 @@ class T5Model(T5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def set_output_past(self, do_output_past: bool):
self.config.output_past = do_output_past
self.decoder.output_past = do_output_past
for block in self.decoder.block:
block.output_past = do_output_past
block.layer[0].SelfAttention.output_past = do_output_past
block.layer[1].EncDecAttention.output_past = do_output_past
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@ -796,6 +927,7 @@ class T5Model(T5PreTrainedModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_value_states=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
head_mask=None,
@ -805,6 +937,11 @@ class T5Model(T5PreTrainedModel):
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -837,16 +974,29 @@ class T5Model(T5PreTrainedModel):
hidden_states = encoder_outputs[0]
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None and self.decoder.output_past is True:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_value_states,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
)
if self.decoder.output_past:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
@ -872,6 +1022,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def get_input_embeddings(self):
return self.shared
def set_output_past(self, do_output_past: bool):
self.config.output_past = do_output_past
self.decoder.output_past = do_output_past
for block in self.decoder.block:
block.output_past = do_output_past
block.layer[0].SelfAttention.output_past = do_output_past
block.layer[1].EncDecAttention.output_past = do_output_past
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
@ -883,6 +1041,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
def forward(
self,
@ -891,6 +1052,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_value_states=None,
lm_labels=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
@ -909,10 +1071,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
If `past_key_value_states` is used only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
@ -948,16 +1114,34 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(lm_labels)
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None and self.decoder.output_past is True:
assert (
lm_labels is None
), "Decoder should not use cached key value states when training. Also consider setting model.set_output_past(False) for less memory consumption"
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_value_states,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
)
# insert decoder past at right place
# to speed up decoding
if self.decoder.output_past:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
sequence_output = decoder_outputs[0]
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
@ -968,9 +1152,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
if lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
decoder_outputs = (
loss,
) + decoder_outputs # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
decoder_outputs = (loss,) + decoder_outputs
return decoder_outputs + encoder_outputs
@ -978,17 +1161,40 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
if type(past) is tuple:
encoder_outputs = past
if len(past) < 2:
encoder_outputs, decoder_past_key_value_states = past, None
else:
encoder_outputs = (past,)
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
return {
"decoder_input_ids": input_ids,
"decoder_past_key_value_states": decoder_past_key_value_states,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
}
def _reorder_cache(self, past, beam_idx):
# past does not have to be re-ordered for T5.
return past
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if len(past) < 2:
logger.warning("You might want to consider setting model.set_output_past(True) to speed up decoding")
return past
decoder_past = past[1]
past = (past[0],)
reordered_decoder_past = ()
for layer_past_states in decoder_past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx),
)
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return past + (reordered_decoder_past,)

View File

@ -1417,17 +1417,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):

View File

@ -128,6 +128,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
config.output_past = False
model = model_class(config)
model.to(torch_device)
model.eval()
@ -144,10 +145,9 @@ class ModelTesterMixin:
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = (
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
)
correct_outlen = 4
decoder_attention_idx = 1
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
decoder_attention_idx += 1

View File

@ -167,17 +167,20 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model = T5Model(config=config)
model.to(torch_device)
model.eval()
decoder_output, encoder_output = model(
decoder_output, decoder_past, encoder_output = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output, decoder_past, encoder_output = model(
input_ids=input_ids, decoder_input_ids=decoder_input_ids
)
result = {
"encoder_output": encoder_output,
"decoder_output": decoder_output,
"decoder_past": decoder_past,
}
self.parent.assertListEqual(
list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
@ -185,6 +188,13 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(
list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
)
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_t5_with_lm_head(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
@ -198,8 +208,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)
loss, prediction_scores, encoder_features = outputs
self.parent.assertEqual(len(outputs), 3)
loss, prediction_scores, _, _ = outputs
self.parent.assertEqual(len(outputs), 4)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
@ -209,6 +219,92 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
)
self.check_loss_output(result)
def create_and_check_t5_decoder_model_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).get_decoder()
model.to(torch_device)
model.eval()
# first forward pass
output, past_key_value_states = model(input_ids)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past, _ = model(next_input_ids)
output_from_past, _ = model(next_tokens, past_key_value_states=past_key_value_states)
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_t5_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).get_decoder()
model.to(torch_device)
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past_key_value_states = model(input_ids, attention_mask=attn_mask)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
)
# get two different outputs
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
output_from_past, _ = model(
next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
)
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_t5_and_check_t5_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
config.num_layers = 1
model = T5ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
torch.manual_seed(0)
model.set_output_past(False)
output_without_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
torch.manual_seed(0)
model.set_output_past(True)
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@ -247,6 +343,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
def test_t5_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
def test_t5_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
def test_t5_generate_with_past_key_value_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: