[T5, generation] Add decoder caching for T5 (#3682)

* initial commit to add decoder caching for T5
* better naming for caching
* finish T5 decoder caching
* correct test
* added extensive past testing for T5
* clean files
* make tests cleaner
* improve docstring
* improve docstring
* better reorder cache
* make style
* Update src/transformers/modeling_t5.py
  Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>
* make set output past work for all layers
* improve docstring
* improve docstring

Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>

This commit is contained in:
parent 9384e5f6de
commit ce2298fb5f
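Editorial note, not part of the commit: a rough usage sketch of what the change enables. It assumes a standard pretrained checkpoint such as t5-small and the set_output_past helper introduced in this commit.

# Illustrative sketch only, not part of the diff: generation with decoder caching enabled.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

input_ids = tokenizer.encode(
    "translate English to German: The house is wonderful.", return_tensors="pt"
)

# With output_past enabled, generate() feeds the cached decoder key/value states back
# into the model at every step, so each step only has to process the newest token.
model.set_output_past(True)
with torch.no_grad():
    generated = model.generate(input_ids, max_length=40)
print(tokenizer.decode(generated[0]))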
@@ -16,7 +16,6 @@

 import copy
-import itertools
 import logging
 import math
 import os
@@ -185,13 +184,11 @@ class T5LayerFF(nn.Module):


 class T5Attention(nn.Module):
-    NEW_ID = itertools.count()
-
     def __init__(self, config, has_relative_attention_bias=False):
         super().__init__()
-        self.layer_id = next(T5Attention.NEW_ID)
         self.is_decoder = config.is_decoder
         self.has_relative_attention_bias = has_relative_attention_bias
+        self.output_past = config.output_past

         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -294,15 +291,37 @@ class T5Attention(nn.Module):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
         return values

-    def forward(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None):
+    def forward(
+        self,
+        input,
+        mask=None,
+        kv=None,
+        position_bias=None,
+        past_key_value_state=None,
+        head_mask=None,
+        query_length=None,
+    ):
         """
         Self-attention (if kv is None) or attention over source sentence (provided by kv).
         """
         # Input is (bs, qlen, dim)
         # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
+        # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
         bs, qlen, dim = input.size()
+
+        if past_key_value_state is not None:
+            assert self.is_decoder is True, "Encoder cannot cache past key value states"
+            assert (
+                len(past_key_value_state) == 2
+            ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
+                len(past_key_value_state)
+            )
+            real_qlen = qlen + past_key_value_state[0].shape[2] if query_length is None else query_length
+        else:
+            real_qlen = qlen
+
         if kv is None:
-            klen = qlen if cache is None else cache["slen"] + qlen
+            klen = real_qlen
         else:
             klen = kv.size(1)
@@ -315,23 +334,27 @@ class T5Attention(nn.Module):
             return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)

         q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)

         if kv is None:
             k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
             v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
-        elif cache is None or self.layer_id not in cache:
+        elif past_key_value_state is None:
             k = v = kv
             k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
             v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)

-        if cache is not None:
-            if self.layer_id in cache:
-                if kv is None:
-                    k_, v_ = cache[self.layer_id]
-                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
-                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
-                else:
-                    k, v = cache[self.layer_id]
-            cache[self.layer_id] = (k, v)
+        if past_key_value_state is not None:
+            if kv is None:
+                k_, v_ = past_key_value_state
+                k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
+                v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
+            else:
+                k, v = past_key_value_state
+
+        if self.is_decoder and self.output_past:
+            present_key_value_state = ((k, v),)
+        else:
+            present_key_value_state = (None,)

         # q = q / math.sqrt(dim_per_head)  # No scaling in T5
         scores = torch.einsum("bnqd,bnkd->bnqk", q, k)  # (bs, n_heads, qlen, klen)
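Editorial aside on the caching logic above, with made-up shapes (not part of the diff): each decoder self-attention layer keeps its key/value projections, and on the next step only the newest token is projected and concatenated onto the cache along the key-length dimension.

# Illustrative sketch only, not part of the diff: how the per-layer cache grows.
# Mirrors the torch.cat(..., dim=2) logic above with toy tensors; names are made up.
import torch

bs, n_heads, dim_per_head = 2, 8, 64

# cache built from the 5 tokens decoded so far
past_k = torch.randn(bs, n_heads, 5, dim_per_head)
past_v = torch.randn(bs, n_heads, 5, dim_per_head)

# projections for the single new token of the current step
new_k = torch.randn(bs, n_heads, 1, dim_per_head)
new_v = torch.randn(bs, n_heads, 1, dim_per_head)

# concatenate along the sequence (klen) dimension, as in T5Attention.forward
k = torch.cat([past_k, new_k], dim=2)
v = torch.cat([past_v, new_v], dim=2)
assert k.shape == (bs, n_heads, 6, dim_per_head)

# real_qlen = qlen + past length, so the relative position bias is computed
# for the full key length even though only one query position is attended from
qlen, real_qlen = 1, 1 + past_k.shape[2]
print(real_qlen, k.shape, v.shape)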
@@ -339,7 +362,13 @@ class T5Attention(nn.Module):
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
-            position_bias = self.compute_bias(qlen, klen)
+            position_bias = self.compute_bias(real_qlen, klen)
+
+            # if key and values are already calculated
+            # we want only the last query position bias
+            if past_key_value_state is not None:
+                position_bias = position_bias[:, :, -1:, :]

         if mask is not None:
             position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
@@ -357,6 +386,13 @@ class T5Attention(nn.Module):
         context = self.o(context)

         outputs = (context,)
+
+        if self.output_past is False or self.is_decoder is False:
+            assert (
+                present_key_value_state[0] is None
+            ), "Key/Value projections should not be stored if {} is not decoder or output_past is False".format(self)
+
+        outputs = outputs + present_key_value_state
         if self.output_attentions:
             outputs = outputs + (weights,)
         if self.has_relative_attention_bias:
@@ -371,10 +407,16 @@ class T5LayerSelfAttention(nn.Module):
         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

-    def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
+    def forward(
+        self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, past_key_value_state=None
+    ):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
-            norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask
+            norm_x,
+            mask=attention_mask,
+            position_bias=position_bias,
+            head_mask=head_mask,
+            past_key_value_state=past_key_value_state,
         )
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y)
@@ -389,10 +431,25 @@ class T5LayerCrossAttention(nn.Module):
         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

-    def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
+    def forward(
+        self,
+        hidden_states,
+        kv,
+        attention_mask=None,
+        position_bias=None,
+        head_mask=None,
+        past_key_value_state=None,
+        query_length=None,
+    ):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
-            norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask
+            norm_x,
+            mask=attention_mask,
+            kv=kv,
+            position_bias=position_bias,
+            head_mask=head_mask,
+            past_key_value_state=past_key_value_state,
+            query_length=query_length,
         )
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y)
@@ -403,14 +460,14 @@ class T5LayerCrossAttention(nn.Module):
 class T5Block(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False):
         super().__init__()
+        self.output_past = config.output_past
         self.is_decoder = config.is_decoder
         self.layer = nn.ModuleList()
         self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
         if self.is_decoder:
             self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
-            self.layer.append(T5LayerFF(config))
-        else:
-            self.layer.append(T5LayerFF(config))
+
+        self.layer.append(T5LayerFF(config))

     def forward(
         self,
@@ -421,31 +478,63 @@ class T5Block(nn.Module):
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         head_mask=None,
+        past_key_value_state=None,
     ):
-        self_attention_outputs = self.layer[0](
-            hidden_states, attention_mask=attention_mask, position_bias=position_bias, head_mask=head_mask
-        )
-        hidden_states = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights
-
-        if not self.is_decoder:
-            hidden_states = self.layer[1](hidden_states)
+        if past_key_value_state is not None:
+            assert self.is_decoder, "Only decoder can use `past_key_value_states`"
+            assert (
+                len(past_key_value_state) == 4
+            ), "The should be 4 past states. 2 (past / key) for self attention. 2 (past / key) for cross attention. Got {} past key / value states".format(
+                len(past_key_value_state)
+            )
+            self_attn_past_key_value_state = past_key_value_state[:2]
+            cross_attn_past_key_value_state = past_key_value_state[2:]
+        else:
+            self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None
+
+        self_attention_outputs = self.layer[0](
+            hidden_states,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            head_mask=head_mask,
+            past_key_value_state=self_attn_past_key_value_state,
+        )
+        hidden_states, present_key_value_state = self_attention_outputs[:2]
+        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
+
+        if self.is_decoder:
+            # the actual query length is unknown for cross attention
+            # if using past key value states. Need to inject it here
+            if present_key_value_state is not None:
+                query_length = present_key_value_state[0].shape[2]
+            else:
+                query_length = None
+
             cross_attention_outputs = self.layer[1](
                 hidden_states,
                 kv=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
                 head_mask=head_mask,
+                past_key_value_state=cross_attn_past_key_value_state,
+                query_length=query_length,
             )
             hidden_states = cross_attention_outputs[0]
-            outputs = (
-                outputs + cross_attention_outputs[1:]
-            )  # Keep cross-attention outputs and relative position weights
-            hidden_states = self.layer[2](hidden_states)
+            # Combine self attn and cross attn key value states
+            if present_key_value_state is not None:
+                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

-        outputs = (hidden_states,) + outputs  # add attentions if we output them
-        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+            # Keep cross-attention outputs and relative position weights
+            attention_outputs = attention_outputs + cross_attention_outputs[2:]
+
+        # Apply Feed Forward layer
+        hidden_states = self.layer[-1](hidden_states)
+        outputs = (hidden_states,)
+
+        # Add attentions if we output them
+        outputs = outputs + (present_key_value_state,) + attention_outputs
+        return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


 class T5PreTrainedModel(PreTrainedModel):
@@ -531,6 +620,7 @@ class T5Stack(T5PreTrainedModel):

         self.embed_tokens = embed_tokens
         self.is_decoder = config.is_decoder
+        self.output_past = config.output_past and self.is_decoder

         self.block = nn.ModuleList(
             [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
@@ -557,6 +647,7 @@ class T5Stack(T5PreTrainedModel):
         encoder_attention_mask=None,
         inputs_embeds=None,
         head_mask=None,
+        past_key_value_states=None,
     ):

         if input_ids is not None and inputs_embeds is not None:
@@ -575,25 +666,41 @@ class T5Stack(T5PreTrainedModel):

         batch_size, seq_length = input_shape

+        if past_key_value_states is not None:
+            assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
+                input_shape, (batch_size, 1)
+            )
+            # required mask seq length can be calculated via length of past
+            # key value states and seq_length = 1 for the last token
+            mask_seq_length = past_key_value_states[0][0].shape[2] + seq_length
+        else:
+            mask_seq_length = seq_length
+
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, seq_length).to(inputs_embeds.device)
-        if self.is_decoder and encoder_attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)

+        # initialize past_key_value_states with `None` if past does not exist
+        if past_key_value_states is None:
+            past_key_value_states = [None] * len(self.block)
+
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
         elif attention_mask.dim() == 2:
-            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # Provided a padding mask of dimensions [batch_size, mask_seq_length]
             # - if the model is a decoder, apply a causal mask in addition to the padding mask
-            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             if self.config.is_decoder:
-                seq_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+                seq_ids = torch.arange(mask_seq_length, device=inputs_embeds.device)
+                causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
                 causal_mask = causal_mask.to(attention_mask)
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+                if self.output_past and past_key_value_states[0] is not None:
+                    extended_attention_mask = extended_attention_mask[:, :, -1:, :]
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
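Editorial sketch of the mask handling above, with toy shapes (not part of the commit): the causal mask is built over the full mask_seq_length (cached length plus the one new token) and then cut down to its last query row, since only that single token is fed through the decoder.

# Illustrative sketch only, not part of the diff: causal-mask slicing when a past is present.
# Variable names are made up; the logic mirrors the hunk above.
import torch

batch_size, past_length, seq_length = 2, 5, 1
mask_seq_length = past_length + seq_length

attention_mask = torch.ones(batch_size, mask_seq_length)

seq_ids = torch.arange(mask_seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(attention_mask)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]

# with cached key/value states only the last query position is processed,
# so only the last row of the causal mask is kept
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
print(extended_attention_mask.shape)  # torch.Size([2, 1, 1, 6])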
@@ -610,9 +717,9 @@ class T5Stack(T5PreTrainedModel):
         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -1e9

-        if self.is_decoder:
+        if self.is_decoder and encoder_attention_mask is not None:
             # If a 2D ou 3D attention mask is provided for the cross-attention
-            # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
+            # we need to make broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length]
             if encoder_attention_mask.dim() == 3:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
             if encoder_attention_mask.dim() == 2:
@@ -633,7 +740,7 @@ class T5Stack(T5PreTrainedModel):
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x mask_seq_length x mask_seq_length]
         if head_mask is not None:
             if head_mask.dim() == 1:
                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
@@ -648,13 +755,15 @@ class T5Stack(T5PreTrainedModel):
         else:
             head_mask = [None] * self.config.num_layers

+        present_key_value_states = ()
         all_hidden_states = ()
         all_attentions = ()
         position_bias = None
         encoder_decoder_position_bias = None

         hidden_states = self.dropout(inputs_embeds)
-        for i, layer_module in enumerate(self.block):
+
+        for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -666,19 +775,22 @@ class T5Stack(T5PreTrainedModel):
                 encoder_attention_mask=encoder_extended_attention_mask,
                 encoder_decoder_position_bias=encoder_decoder_position_bias,
                 head_mask=head_mask[i],
+                past_key_value_state=past_key_value_state,
             )
             # layer_outputs is a tuple with:
-            # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
-            hidden_states = layer_outputs[0]
+            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+            hidden_states, present_key_value_state = layer_outputs[:2]
             if i == 0:
                 # We share the position biases between the layers - the first layer store them
-                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
-                position_bias = layer_outputs[2 if self.output_attentions else 1]
+                # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+                position_bias = layer_outputs[3 if self.output_attentions else 2]
                 if self.is_decoder:
-                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
+                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
+            # append next layer key value states
+            present_key_value_states = present_key_value_states + (present_key_value_state,)

             if self.output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)  # We keep only self-attention weights for now
+                all_attentions = all_attentions + (layer_outputs[2],)  # We keep only self-attention weights for now

         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -688,11 +800,13 @@ class T5Stack(T5PreTrainedModel):
             all_hidden_states = all_hidden_states + (hidden_states,)

         outputs = (hidden_states,)
+        if self.is_decoder and self.output_past:
+            outputs = outputs + (present_key_value_states,)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             outputs = outputs + (all_attentions,)
-        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
+        return outputs  # last-layer hidden state, (presents,) (all hidden states), (all attentions)


 T5_START_DOCSTRING = r""" The T5 model was proposed in
@@ -719,7 +833,7 @@ T5_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
-            T5 is a model with relative position embeddings so you should be able to pad the inputs on
+            T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. If `decoder_past_key_value_states` is used, optionally only the last `input_ids` have to be input (see `decoder_past_key_value_states`).
             Indices can be obtained using :class:`transformers.T5Tokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
             :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
@@ -739,6 +853,9 @@ T5_INPUTS_DOCSTRING = r"""
             `T5 Training <./t5.html#training>`_ .
         decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
             Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
+        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains pre-computed key and value hidden-states of the attention blocks.
+            Can be used to speed up decoding. If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
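Editorial usage sketch of the argument documented above (not part of the commit). It assumes a t5-small checkpoint and the output packing used by T5Model.forward in this commit, where the second return element holds the encoder outputs together with the per-layer decoder key/value states.

# Illustrative sketch only, not part of the diff: stepping T5Model manually with
# decoder_past_key_value_states.
import torch
from transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
model.eval()

input_ids = tokenizer.encode("Studies have shown that owning a dog is good for you", return_tensors="pt")
decoder_input_ids = tokenizer.encode("Studies show that", return_tensors="pt")

with torch.no_grad():
    # first pass over the full decoder prefix; outputs[1] packs
    # (encoder_outputs, per-layer decoder key/value states)
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    encoder_outputs, decoder_past = outputs[1]

    # next pass: only the newest decoder token is needed, the rest comes from the cache
    next_token = torch.tensor([[42]], dtype=torch.long)  # arbitrary token id for illustration
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=next_token,
        decoder_past_key_value_states=decoder_past,
    )

print(outputs[0].shape)  # (batch_size, 1, d_model)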
@@ -780,6 +897,20 @@ class T5Model(T5PreTrainedModel):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)

+    def set_output_past(self, do_output_past: bool):
+        self.config.output_past = do_output_past
+        self.decoder.output_past = do_output_past
+        for block in self.decoder.block:
+            block.output_past = do_output_past
+            block.layer[0].SelfAttention.output_past = do_output_past
+            block.layer[1].EncDecAttention.output_past = do_output_past
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -796,6 +927,7 @@ class T5Model(T5PreTrainedModel):
         encoder_outputs=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        decoder_past_key_value_states=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
         head_mask=None,
@@ -805,6 +937,11 @@
         :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
         last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
+            If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
+        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
+            Contains pre-computed key and value hidden-states of the attention blocks.
+            Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
+            Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -837,16 +974,29 @@ class T5Model(T5PreTrainedModel):

         hidden_states = encoder_outputs[0]

+        # If decoding with past key value states, only the last tokens
+        # should be given as an input
+        if decoder_past_key_value_states is not None and self.decoder.output_past is True:
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids[:, -1:]
+            if decoder_inputs_embeds is not None:
+                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
+            past_key_value_states=decoder_past_key_value_states,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=head_mask,
         )

+        if self.decoder.output_past:
+            past = ((encoder_outputs, decoder_outputs[1]),)
+            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
+
         return decoder_outputs + encoder_outputs
@@ -872,6 +1022,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def get_input_embeddings(self):
         return self.shared

+    def set_output_past(self, do_output_past: bool):
+        self.config.output_past = do_output_past
+        self.decoder.output_past = do_output_past
+        for block in self.decoder.block:
+            block.output_past = do_output_past
+            block.layer[0].SelfAttention.output_past = do_output_past
+            block.layer[1].EncDecAttention.output_past = do_output_past
+
     def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
@@ -883,6 +1041,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def get_encoder(self):
         return self.encoder

+    def get_decoder(self):
+        return self.decoder
+
     @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -891,6 +1052,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         encoder_outputs=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        decoder_past_key_value_states=None,
         lm_labels=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
@@ -909,10 +1071,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             Classification loss (cross entropy).
         prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+            If `past_key_value_states` is used only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
+        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
+            Contains pre-computed key and value hidden-states of the attention blocks.
+            Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
+            Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.

             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
         attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
@@ -948,16 +1114,34 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(lm_labels)

+        # If decoding with past key value states, only the last tokens
+        # should be given as an input
+        if decoder_past_key_value_states is not None and self.decoder.output_past is True:
+            assert (
+                lm_labels is None
+            ), "Decoder should not use cached key value states when training. Also consider setting model.set_output_past(False) for less memory consumption"
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids[:, -1:]
+            if decoder_inputs_embeds is not None:
+                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
+            past_key_value_states=decoder_past_key_value_states,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=head_mask,
         )

+        # insert decoder past at right place
+        # to speed up decoding
+        if self.decoder.output_past:
+            past = ((encoder_outputs, decoder_outputs[1]),)
+            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
+
         sequence_output = decoder_outputs[0]
         # Rescale output before projecting on vocab
         # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
@@ -968,9 +1152,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         if lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-100)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
-            decoder_outputs = (
-                loss,
-            ) + decoder_outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+            decoder_outputs = (loss,) + decoder_outputs

         return decoder_outputs + encoder_outputs
@@ -978,17 +1161,40 @@
         assert past is not None, "past has to be defined for encoder_outputs"

         # first step
-        if type(past) is tuple:
-            encoder_outputs = past
+        if len(past) < 2:
+            encoder_outputs, decoder_past_key_value_states = past, None
         else:
-            encoder_outputs = (past,)
+            encoder_outputs, decoder_past_key_value_states = past[0], past[1]

         return {
             "decoder_input_ids": input_ids,
+            "decoder_past_key_value_states": decoder_past_key_value_states,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
         }

     def _reorder_cache(self, past, beam_idx):
-        # past does not have to be re-ordered for T5.
-        return past
+        # if decoder past is not included in output
+        # speedy decoding is disabled and no need to reorder
+        if len(past) < 2:
+            logger.warning("You might want to consider setting model.set_output_past(True) to speed up decoding")
+            return past
+
+        decoder_past = past[1]
+        past = (past[0],)
+        reordered_decoder_past = ()
+        for layer_past_states in decoder_past:
+            # get the correct batch idx from layer past batch dim
+            # batch dim of `past` is at 2nd position
+            reordered_layer_past_states = ()
+            for layer_past_state in layer_past_states:
+                # need to set correct `past` for each of the four key / value states
+                reordered_layer_past_states = reordered_layer_past_states + (
+                    layer_past_state.index_select(0, beam_idx),
+                )
+
+            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
+            assert len(reordered_layer_past_states) == len(layer_past_states)
+
+            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+        return past + (reordered_decoder_past,)
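Editorial toy example of the reordering above, with made-up shapes (not part of the commit): during beam search the per-layer key/value tensors have their batch (beam) dimension re-indexed so every surviving beam keeps the cache of the beam it was forked from, while the encoder outputs in past[0] are left untouched.

# Illustrative sketch only, not part of the diff: reordering a toy decoder cache the way
# the new T5 _reorder_cache does.
import torch

num_beams, n_heads, seq_len, dim_per_head = 4, 8, 5, 64

# one decoder layer worth of cache: self-attn key/value + cross-attn key/value
layer_past_states = tuple(torch.randn(num_beams, n_heads, seq_len, dim_per_head) for _ in range(4))
decoder_past = (layer_past_states,)  # a single-layer "decoder past"

# beams 2 and 0 survived and were duplicated
beam_idx = torch.tensor([2, 2, 0, 0])

reordered_decoder_past = ()
for layer_past_states in decoder_past:
    reordered_layer_past_states = tuple(
        layer_past_state.index_select(0, beam_idx) for layer_past_state in layer_past_states
    )
    reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)

# every beam row now holds the cache of the beam it was forked from
assert torch.equal(reordered_decoder_past[0][0][0], decoder_past[0][0][2])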
@@ -1417,17 +1417,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

     @staticmethod
     def _reorder_cache(past, beam_idx):
-        reordered_past = []
-        for layer_past in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` and `mems` is at 2nd position
-            reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
-            reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
-            # check that shape matches
-            assert reordered_layer_past.shape == layer_past.shape
-            reordered_past.append(reordered_layer_past)
-        past = tuple(reordered_past)
-        return past
+        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


 def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
@@ -128,6 +128,7 @@ class ModelTesterMixin:
         for model_class in self.all_model_classes:
             config.output_attentions = True
             config.output_hidden_states = False
+            config.output_past = False
             model = model_class(config)
             model.to(torch_device)
             model.eval()
@@ -144,10 +145,9 @@
             out_len = len(outputs)

             if self.is_encoder_decoder:
-                correct_outlen = (
-                    4  # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
-                )
+                correct_outlen = 4
+                decoder_attention_idx = 1

                 if "lm_labels" in inputs_dict:  # loss will come first
                     correct_outlen += 1  # compute loss
                     decoder_attention_idx += 1
@@ -167,17 +167,20 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
            model = T5Model(config=config)
            model.to(torch_device)
            model.eval()
-            decoder_output, encoder_output = model(
+            decoder_output, decoder_past, encoder_output = model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
-            decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+            decoder_output, decoder_past, encoder_output = model(
+                input_ids=input_ids, decoder_input_ids=decoder_input_ids
+            )

            result = {
                "encoder_output": encoder_output,
                "decoder_output": decoder_output,
+                "decoder_past": decoder_past,
            }
            self.parent.assertListEqual(
                list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
@@ -185,6 +188,13 @@
            self.parent.assertListEqual(
                list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
            )
+            self.parent.assertEqual(len(decoder_past), 2)
+            # decoder_past[0] should correspond to encoder output
+            self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
+            # There should be `num_layers` key value embeddings stored in decoder_past[1]
+            self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
+            # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
+            self.parent.assertEqual(len(decoder_past[1][0]), 4)

        def create_and_check_t5_with_lm_head(
            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
@@ -198,8 +208,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
                decoder_attention_mask=decoder_attention_mask,
                lm_labels=lm_labels,
            )
-            loss, prediction_scores, encoder_features = outputs
-            self.parent.assertEqual(len(outputs), 3)
+            loss, prediction_scores, _, _ = outputs
+            self.parent.assertEqual(len(outputs), 4)
            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
@@ -209,6 +219,92 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
            )
            self.check_loss_output(result)

+        def create_and_check_t5_decoder_model_past(
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+        ):
+            model = T5Model(config=config).get_decoder()
+            model.to(torch_device)
+            model.eval()
+
+            # first forward pass
+            output, past_key_value_states = model(input_ids)
+
+            # create hypothetical next token and extent to next_input_ids
+            next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+            # append to next input_ids and
+            next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+
+            output_from_no_past, _ = model(next_input_ids)
+            output_from_past, _ = model(next_tokens, past_key_value_states=past_key_value_states)
+
+            # select random slice
+            random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+            output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+            output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+            # test that outputs are equal for slice
+            self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+        def create_and_check_t5_decoder_model_attention_mask_past(
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+        ):
+            model = T5Model(config=config).get_decoder()
+            model.to(torch_device)
+            model.eval()
+
+            # create attention mask
+            attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+            half_seq_length = input_ids.shape[-1] // 2
+            attn_mask[:, half_seq_length:] = 0
+
+            # first forward pass
+            output, past_key_value_states = model(input_ids, attention_mask=attn_mask)
+
+            # create hypothetical next token and extent to next_input_ids
+            next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+            # change a random masked slice from input_ids
+            random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
+            random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
+            input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
+
+            # append to next input_ids and attn_mask
+            next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            attn_mask = torch.cat(
+                [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
+            )
+
+            # get two different outputs
+            output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
+            output_from_past, _ = model(
+                next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
+            )
+
+            # select random slice
+            random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+            output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
+            output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
+
+            # test that outputs are equal for slice
+            self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+        def create_t5_and_check_t5_generate_with_past_key_value_states(
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+        ):
+            config.num_layers = 1
+            model = T5ForConditionalGeneration(config=config)
+            model.to(torch_device)
+            model.eval()
+            torch.manual_seed(0)
+            model.set_output_past(False)
+            output_without_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
+            torch.manual_seed(0)
+            model.set_output_past(True)
+            output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
+            self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
+
        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (
@@ -247,6 +343,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

+    def test_t5_decoder_model_past(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
+
+    def test_t5_decoder_model_past_with_attn_mask(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
+
+    def test_t5_generate_with_past_key_value_states(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
+
    @slow
    def test_model_from_pretrained(self):
        for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: