Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
cache support
This commit is contained in:
parent 6afc75bf36
commit 1eaca54b02
@@ -25,6 +25,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_layers import GradientCheckpointingLayer
@@ -243,7 +244,7 @@ def eager_attn_forward(


 class BertSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None, is_causal=False):
+    def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -271,6 +272,7 @@ class BertSelfAttention(nn.Module):

         self.is_decoder = config.is_decoder
         self.is_causal = is_causal
+        self.layer_idx = layer_idx

     def forward(
         self,
@@ -281,6 +283,7 @@ class BertSelfAttention(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,  # Only kept to be BC
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> tuple[torch.Tensor]:
         # If this is instantiated as a cross-attention module, the keys
@@ -301,25 +304,34 @@ class BertSelfAttention(nn.Module):
         # get query proj
         query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)

-        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
-        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
-            key_layer, value_layer = past_key_value
+        if past_key_value is not None:
+            if isinstance(past_key_value, EncoderDecoderCache):
+                is_updated = past_key_value.is_updated.get(self.layer_idx)
+                if is_cross_attention:
+                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
+                    curr_past_key_value = past_key_value.cross_attention_cache
+                else:
+                    curr_past_key_value = past_key_value.self_attention_cache
+            else:
+                curr_past_key_value = past_key_value
+
+        if is_cross_attention and past_key_value is not None and is_updated:
+            # reuse k,v, cross_attentions
+            key_layer = curr_past_key_value.key_cache[self.layer_idx]
+            value_layer = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_layer = self.key(current_states).view(*kv_input_shape).transpose(1, 2)
             value_layer = self.value(current_states).view(*kv_input_shape).transpose(1, 2)
-            if past_key_value is not None and not is_cross_attention:
-                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
+            if past_key_value is not None:
+                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_layer, value_layer = curr_past_key_value.update(
+                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
+                )
+                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+                if is_cross_attention:
+                    past_key_value.is_updated[self.layer_idx] = True

         attention_interface: Callable = eager_attn_forward
         if self.config._attn_implementation != "eager":
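As context (this sketch is not part of the commit): the rewritten attention path above keys everything off `EncoderDecoderCache`, which bundles one sub-cache for self-attention and one for cross-attention and tracks a per-layer `is_updated` flag so cross-attention keys/values are projected only once per encoder output. A minimal sketch of that contract, with purely illustrative shapes:

# Sketch only: how the cache object used by the new BertSelfAttention behaves.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

batch, heads, head_dim = 1, 12, 64
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())  # (self-attention cache, cross-attention cache)

# Prefill: layer 0 stores keys/values for a 5-token prompt in the self-attention cache.
k = torch.randn(batch, heads, 5, head_dim)
v = torch.randn(batch, heads, 5, head_dim)
k_all, v_all = cache.self_attention_cache.update(k, v, layer_idx=0)
print(k_all.shape)  # torch.Size([1, 12, 5, 64])

# Decoding step: only the new token is projected; update() returns the full concatenated state.
k_new, v_new = torch.randn(batch, heads, 1, head_dim), torch.randn(batch, heads, 1, head_dim)
k_all, v_all = cache.self_attention_cache.update(k_new, v_new, layer_idx=0)
print(k_all.shape, cache.get_seq_length())  # torch.Size([1, 12, 6, 64]) 6

# Cross-attention keys/values depend only on the encoder output, so a layer writes them once,
# flips its is_updated flag, and later reads key_cache/value_cache directly.
cache.cross_attention_cache.update(k, v, layer_idx=0)
cache.is_updated[0] = True
print(cache.is_updated.get(0))  # True

The `{"cache_position": cache_position}` kwarg passed to `update()` in the diff matters for fixed-size caches such as `StaticCache`; a `DynamicCache` simply concatenates along the sequence dimension.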
@@ -369,9 +381,9 @@ class BertSelfOutput(nn.Module):


 class BertAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None, is_causal=False):
+    def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
         super().__init__()
-        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type, is_causal=is_causal)
+        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type, is_causal=is_causal, layer_idx=layer_idx)
         self.output = BertSelfOutput(config)
         self.pruned_heads = set()
@@ -402,6 +414,7 @@ class BertAttention(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
@@ -411,6 +424,7 @@ class BertAttention(nn.Module):
             encoder_attention_mask,
             past_key_value,
             output_attentions,
+            cache_position,
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
@@ -447,17 +461,17 @@ class BertOutput(nn.Module):


 class BertLayer(GradientCheckpointingLayer):
-    def __init__(self, config):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = BertAttention(config, is_causal=config.is_decoder)
+        self.attention = BertAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = BertAttention(config, position_embedding_type="absolute", is_causal=False)
+            self.crossattention = BertAttention(config, position_embedding_type="absolute", is_causal=False, layer_idx=layer_idx)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

@@ -470,26 +484,24 @@ class BertLayer(GradientCheckpointingLayer):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor]:
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
             head_mask,
             output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
         )
         attention_output = self_attention_outputs[0]

         # if decoder, the last output is tuple of self-attn cache
         if self.is_decoder:
             outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
         else:
             outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

-        cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
             if not hasattr(self, "crossattention"):
                 raise ValueError(
@@ -497,24 +509,19 @@ class BertLayer(GradientCheckpointingLayer):
                     " by setting `config.add_cross_attention=True`"
                 )

-            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
             cross_attention_outputs = self.crossattention(
                 attention_output,
                 attention_mask,
                 head_mask,
                 encoder_hidden_states,
                 encoder_attention_mask,
-                cross_attn_past_key_value,
+                past_key_value,
                 output_attentions,
+                cache_position,
             )
             attention_output = cross_attention_outputs[0]
             outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

-            # add cross-attn cache to positions 3,4 of present_key_value tuple
-            cross_attn_present_key_value = cross_attention_outputs[-1]
-            present_key_value = present_key_value + cross_attn_present_key_value
-
         layer_output = apply_chunking_to_forward(
             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
         )
@@ -522,7 +529,7 @@ class BertLayer(GradientCheckpointingLayer):

         # if decoder, return the attn key/values as the last output
         if self.is_decoder:
-            outputs = outputs + (present_key_value,)
+            outputs = outputs + (past_key_value,)

         return outputs

@@ -536,7 +543,7 @@ class BertEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False

     def forward(
@@ -551,6 +558,7 @@ class BertEncoder(nn.Module):
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -563,13 +571,22 @@ class BertEncoder(nn.Module):
                 )
                 use_cache = False

-        next_decoder_cache = () if use_cache else None
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            logger.warning_once(
+                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+            )
+            return_legacy_cache = True
+            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+        next_decoder_cache = None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer_head_mask = head_mask[i] if head_mask is not None else None
-            past_key_value = past_key_values[i] if past_key_values is not None else None

             layer_outputs = layer_module(
                 hidden_states,
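As context (this sketch is not part of the commit): the deprecation branch above accepts the old per-layer tuples, wraps them in an `EncoderDecoderCache`, and converts back with `to_legacy_cache()` on the way out so existing callers keep receiving tuples. A hedged sketch of that round trip with made-up tensors:

# Sketch only: legacy tuple format <-> EncoderDecoderCache, as done by the branch above.
import torch
from transformers.cache_utils import EncoderDecoderCache

# Legacy layout: one tuple per layer holding (self_key, self_value, cross_key, cross_value).
layer = tuple(torch.randn(1, 12, 5, 64) for _ in range(4))
legacy = (layer, layer)  # two layers, purely illustrative

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(type(cache.self_attention_cache).__name__)  # DynamicCache
print(len(cache.to_legacy_cache()))                # 2 -> back to one tuple per layer

With this in place, `next_decoder_cache` inside the loop is simply the (possibly converted) cache object rather than a tuple that grows layer by layer.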
@@ -577,13 +594,14 @@ class BertEncoder(nn.Module):
                 layer_head_mask,
                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
                 encoder_attention_mask=encoder_attention_mask,
-                past_key_value=past_key_value,
+                past_key_value=past_key_values,
                 output_attentions=output_attentions,
+                cache_position=cache_position,
             )

             hidden_states = layer_outputs[0]
             if use_cache:
-                next_decoder_cache += (layer_outputs[-1],)
+                next_decoder_cache = layer_outputs[-1]
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
                 if self.config.add_cross_attention:
@@ -592,12 +610,16 @@ class BertEncoder(nn.Module):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = past_key_values.to_legacy_cache()
+
         if not return_dict:
             return tuple(
                 v
                 for v in [
                     hidden_states,
-                    next_decoder_cache,
+                    next_cache,
                     all_hidden_states,
                     all_self_attentions,
                     all_cross_attentions,
@@ -606,7 +628,7 @@ class BertEncoder(nn.Module):
             )
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=next_decoder_cache,
+            past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
             cross_attentions=all_cross_attentions,
@@ -817,6 +839,7 @@ class BertModel(BertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -842,8 +865,13 @@ class BertModel(BertPreTrainedModel):
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device

-        # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = 0
+        if past_key_values is not None:
+            past_key_values_length = (
+                past_key_values[0][0].shape[-2]
+                if not isinstance(past_key_values, Cache)
+                else past_key_values.get_seq_length()
+            )

         if token_type_ids is None:
             if hasattr(self.embeddings, "token_type_ids"):
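As context (this sketch is not part of the commit): both branches above answer the same question, how many positions are already cached, for the two accepted input types. A quick illustrative check with made-up tensors:

# Sketch only: legacy tuples and Cache objects report the same past length.
import torch
from transformers.cache_utils import EncoderDecoderCache

legacy = ((torch.randn(1, 12, 7, 64), torch.randn(1, 12, 7, 64)),)  # one layer, self-attention only
cache = EncoderDecoderCache.from_legacy_cache(legacy)

assert legacy[0][0].shape[-2] == cache.get_seq_length() == 7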
@@ -928,6 +956,7 @@ class BertModel(BertPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@@ -1094,6 +1123,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **loss_kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
@@ -1120,6 +1150,7 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         sequence_output = outputs[0]
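As context (this sketch is not part of the commit): with `cache_position` threaded through `BertLMHeadModel`, `GenerationMixin.generate()` can drive the model with a growing cache. A hedged end-to-end sketch using a tiny, randomly initialized config (sizes are illustrative):

# Sketch only: greedy generation through the cached decoder path.
import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128, is_decoder=True
)
model = BertLMHeadModel(config).eval()

prompt = torch.randint(0, config.vocab_size, (1, 4))
out = model.generate(prompt, max_new_tokens=5, do_sample=False, use_cache=True)
print(out.shape)  # torch.Size([1, 9]): 4 prompt tokens + 5 generated ones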
@@ -743,9 +743,11 @@ class BertModelIntegrationTest(unittest.TestCase):
             torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
         )

-        # Case where query length != kv_length.
-        res_eager = model(**inp, past_key_values=pkv)
-        res_sdpa = model_sdpa(**inp, past_key_values=pkv)
+        # Case where query length != kv_length. Note that model needs to be a decoder so we can use cache
+        model.config.is_decoder = True
+        model_sdpa.config.is_decoder = True
+        res_eager = model(**inp, past_key_values=pkv, use_cache=True)
+        res_sdpa = model_sdpa(**inp, past_key_values=pkv, use_cache=True)
         self.assertTrue(
             torch.allclose(res_eager.last_hidden_state, res_sdpa.last_hidden_state, atol=1e-5, rtol=1e-4)
         )
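As context (this sketch is not part of the commit): the `pkv` used in the test is built earlier in the test file; the scenario it exercises, prime a cache with a prefix and then feed a shorter input so the query length differs from the key/value length, can be sketched as follows with random weights and illustrative sizes:

# Sketch only: cache-then-reuse call pattern along the lines of the test above.
import torch
from transformers import BertConfig, BertModel

config = BertConfig(
    hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128, is_decoder=True
)
model = BertModel(config).eval()

prefix_ids = torch.randint(0, config.vocab_size, (1, 6))
pkv = model(input_ids=prefix_ids, use_cache=True).past_key_values  # cache now holds 6 positions

next_ids = torch.randint(0, config.vocab_size, (1, 1))             # query length 1, kv length 7
out = model(input_ids=next_ids, past_key_values=pkv, use_cache=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 1, 64])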