Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix mask slicing for models with HybridCache (#35681)
* correctly slice
* check mask
* Update modular_gemma2.py
* fix
* add tests
* fix typo
* finally fix mask slicing
* Finally correctly slice in all cases!!
* add test for all attention functions
* small fix in tests
* trick around dynamo tracing issue
* last update
* more robust
* kwargs propagation
* make it explicit for checkpointing
* apply modular
This commit is contained in:
parent
b764c20b09
commit
3f860dba55
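The diffs below touch the Cohere2 and Gemma2 modeling, modular and test files: they add a `last_cache_position` argument that is threaded from `prepare_inputs_for_generation` down to the decoder layers, and they rework how the sliding-window attention mask is sliced for models that use a HybridCache. As a reading aid, here is a minimal self-contained sketch of the slicing logic the commit introduces; the helper name and signature are invented for illustration and this is not the library code itself.

import torch

def slice_sliding_window_mask(attention_mask, cache_position, last_cache_position, sliding_window, is_fa2):
    # In prefill we may process more tokens than the sliding window holds.
    effective_seq_len = max(cache_position.shape[0], sliding_window)
    if is_fa2:
        # FA2 masks are 2D [bs, processed_tokens]: keep at most `effective_seq_len`
        # columns, sliced from the right.
        return attention_mask[:, -effective_seq_len:]
    # Other backends see a 4D float mask [bs, 1, query_len, max_cache_len]: first mask
    # out positions older than the sliding window, then slice from the left with an
    # offset once generation has moved beyond the window.
    min_dtype = torch.finfo(attention_mask.dtype).min
    sliding_window_mask = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
    )
    attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
    offset = max(0, last_cache_position - effective_seq_len)  # only > 0 beyond the window
    return attention_mask[:, :, :, offset : offset + effective_seq_len]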
@@ -260,6 +260,11 @@ class Cohere2Attention(nn.Module):
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Here we need to slice as we use a static cache by default, but FA2 does not support it
        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
            seq_len = attention_mask.shape[-1]
            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -323,6 +328,7 @@ class Cohere2DecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: int = 0,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
@@ -343,21 +349,30 @@ class Cohere2DecoderLayer(nn.Module):
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
        """

        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            # In prefill, we may be larger than sliding window
            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
            # thus we must slice from the right (at most `effective_seq_len` elements)
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
                attention_mask = attention_mask[:, -effective_seq_len:]
            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
            # from the left, with an offset if we are beyond the sliding window
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
                offset = last_cache_position - effective_seq_len
                # Should only be used when beyond the sliding window (i.e. offset > 0)
                offset = max(0, offset)
                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

        residual = hidden_states

@@ -557,6 +572,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: Optional[int] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -595,9 +611,20 @@ class Cohere2Model(Cohere2PreTrainedModel):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        if last_cache_position is None:
            last_cache_position = 0
            if attention_mask is not None:
                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
                last_cache_position = (
                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
                )
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
@@ -621,6 +648,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
                    output_attentions,
                    use_cache,
                    cache_position,
                    last_cache_position,
                )
            else:
                layer_outputs = decoder_layer(
@@ -631,6 +659,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    last_cache_position=last_cache_position,
                    **flash_attn_kwargs,
                )

@@ -917,6 +946,10 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
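For intuition, a short worked example of the offset arithmetic used in the 4D branch above (the window size of 4096 matches these models; all other numbers are made up for illustration):

sliding_window = 4096

# Decoding a single token after a 5000-token prompt: cache_position has length 1
# and `last_cache_position` (the 2D mask length inside `generate`) is 5001.
effective_seq_len = max(1, sliding_window)        # -> 4096
offset = max(0, 5001 - effective_seq_len)         # -> 905
# mask[:, :, :, 905 : 905 + 4096] keeps exactly the most recent 4096 cache slots.

# During the 5000-token prefill itself, cache_position has length 5000:
effective_seq_len = max(5000, sliding_window)     # -> 5000
offset = max(0, 5000 - effective_seq_len)         # -> 0, keep the whole prompt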
@@ -305,6 +305,11 @@ class Cohere2Attention(CohereAttention, nn.Module):
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Here we need to slice as we use a static cache by default, but FA2 does not support it
        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
            seq_len = attention_mask.shape[-1]
            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -349,6 +354,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: int = 0,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
@@ -369,21 +375,30 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
        """

        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            # In prefill, we may be larger than sliding window
            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
            # thus we must slice from the right (at most `effective_seq_len` elements)
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
                attention_mask = attention_mask[:, -effective_seq_len:]
            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
            # from the left, with an offset if we are beyond the sliding window
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
                offset = last_cache_position - effective_seq_len
                # Should only be used when beyond the sliding window (i.e. offset > 0)
                offset = max(0, offset)
                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

        residual = hidden_states

@@ -443,6 +458,7 @@ class Cohere2Model(Gemma2Model):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: Optional[int] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -481,9 +497,20 @@ class Cohere2Model(Gemma2Model):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        if last_cache_position is None:
            last_cache_position = 0
            if attention_mask is not None:
                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
                last_cache_position = (
                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
                )
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
@@ -507,6 +534,7 @@ class Cohere2Model(Gemma2Model):
                    output_attentions,
                    use_cache,
                    cache_position,
                    last_cache_position,
                )
            else:
                layer_outputs = decoder_layer(
@@ -517,6 +545,7 @@ class Cohere2Model(Gemma2Model):
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    last_cache_position=last_cache_position,
                    **flash_attn_kwargs,
                )

@@ -586,6 +615,10 @@ class Cohere2ForCausalLM(CohereForCausalLM):
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
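As a side note on the 4D branch, here is a tiny toy example (shapes chosen arbitrarily, not library code) of what the `torch.tril` / `torch.where` pair in the hunks above does during prefill:

import torch

bs, seq_len, sliding_window = 1, 8, 3
min_dtype = torch.finfo(torch.float32).min

# A prefill-style causal mask: 0.0 where attention is allowed, min_dtype on future positions.
attention_mask = torch.zeros(bs, 1, seq_len, seq_len)
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
attention_mask = attention_mask.masked_fill(future, min_dtype)

# The added lines: also mask keys more than `sliding_window` positions behind each query row.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# The last query row can now only attend to the previous `sliding_window` keys.
print((attention_mask[0, 0, -1] == 0).nonzero().flatten())  # tensor([5, 6, 7])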
@@ -221,9 +221,19 @@ class Gemma2Attention(nn.Module):

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "sliding_window": self.sliding_window,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Here we need to slice as we use a static cache by default, but FA2 does not support it
        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
            seq_len = attention_mask.shape[-1]
            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -277,20 +287,30 @@ class Gemma2DecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: int = 0,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            # In prefill, we may be larger than sliding window
            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
            # thus we must slice from the right (at most `effective_seq_len` elements)
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
                attention_mask = attention_mask[:, -effective_seq_len:]
            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
            # from the left, with an offset if we are beyond the sliding window
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
                offset = last_cache_position - effective_seq_len
                # Should only be used when beyond the sliding window (i.e. offset > 0)
                offset = max(0, offset)
                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

        residual = hidden_states

@@ -306,6 +326,7 @@ class Gemma2DecoderLayer(nn.Module):
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
@@ -554,6 +575,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: Optional[int] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -593,6 +615,16 @@ class Gemma2Model(Gemma2PreTrainedModel):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        if last_cache_position is None:
            last_cache_position = 0
            if attention_mask is not None:
                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
                last_cache_position = (
                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
                )
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
@@ -628,6 +660,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
                    output_attentions,
                    use_cache,
                    cache_position,
                    last_cache_position,
                )
            else:
                layer_outputs = decoder_layer(
@@ -639,6 +672,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    last_cache_position=last_cache_position,
                    **flash_attn_kwargs,
                )

@@ -857,6 +891,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **loss_kwargs,
        )

        hidden_states = outputs[0]
@@ -926,6 +961,10 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
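The attention hunks above also start passing `sliding_window` through `cache_kwargs` into `past_key_value.update(...)`. Roughly speaking, this tells the sliding-window layers of the HybridCache to keep only the most recent window of key/value states; a toy sketch of the idea (not the HybridCache implementation) follows.

import torch

def toy_sliding_update(cached, new_states, sliding_window):
    # Append the new key/value states, then keep only the last `sliding_window` positions.
    cached = torch.cat([cached, new_states], dim=-2)
    return cached[:, :, -sliding_window:, :]

# e.g. a [bs, num_kv_heads, seq, head_dim] buffer never grows past `sliding_window`
cache = torch.zeros(1, 2, 0, 16)
for _ in range(10):
    cache = toy_sliding_update(cache, torch.randn(1, 2, 1, 16), sliding_window=4)
print(cache.shape)  # torch.Size([1, 2, 4, 16])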
@@ -265,9 +265,19 @@ class Gemma2Attention(GemmaAttention):

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "sliding_window": self.sliding_window,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Here we need to slice as we use a static cache by default, but FA2 does not support it
        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
            seq_len = attention_mask.shape[-1]
            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -321,20 +331,30 @@ class Gemma2DecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: int = 0,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            # In prefill, we may be larger than sliding window
            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
            # thus we must slice from the right (at most `effective_seq_len` elements)
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
                attention_mask = attention_mask[:, -effective_seq_len:]
            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
            # from the left, with an offset if we are beyond the sliding window
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
                offset = last_cache_position - effective_seq_len
                # Should only be used when beyond the sliding window (i.e. offset > 0)
                offset = max(0, offset)
                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]

        residual = hidden_states

@@ -350,6 +370,7 @@ class Gemma2DecoderLayer(nn.Module):
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
@@ -387,6 +408,7 @@ class Gemma2Model(GemmaModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        last_cache_position: Optional[int] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -426,6 +448,16 @@ class Gemma2Model(GemmaModel):
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        if last_cache_position is None:
            last_cache_position = 0
            if attention_mask is not None:
                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
                last_cache_position = (
                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
                )
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
@@ -461,6 +493,7 @@ class Gemma2Model(GemmaModel):
                    output_attentions,
                    use_cache,
                    cache_position,
                    last_cache_position,
                )
            else:
                layer_outputs = decoder_layer(
@@ -472,6 +505,7 @@ class Gemma2Model(GemmaModel):
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    last_cache_position=last_cache_position,
                    **flash_attn_kwargs,
                )

@@ -589,6 +623,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **loss_kwargs,
        )

        hidden_states = outputs[0]
@@ -658,6 +693,10 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
        # (retrieving the same value from `cache_position` later on would crash dynamo)
        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
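The same `last_cache_position` bookkeeping is repeated in all four model files above. Condensed into one hypothetical helper (a sketch, not a library function), the rule the forward passes apply when the caller does not provide the value is:

def infer_last_cache_position(last_cache_position, attention_mask, cache_position):
    # Sketch of the fallback logic from the hunks above.
    if last_cache_position is not None:
        return last_cache_position
    if attention_mask is None:
        return 0
    if attention_mask.dim() == 2:
        # The `generate` path: the 2D mask already counts every processed token.
        return attention_mask.shape[-1]
    # A 4D mask passed directly (rare): fall back to cache_position, which breaks dynamo tracing.
    return cache_position[-1].item()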
@@ -324,3 +324,36 @@ class Cohere2IntegrationTest(unittest.TestCase):
        )
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

    @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
    @require_read_token
    def test_generation_beyond_sliding_window(self, attn_implementation: str):
        """Test that we can correctly generate beyond the sliding window. This is non trivial as
        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
        Outputs for every attention functions should be coherent and identical.
        """
        model_id = "CohereForAI/c4ai-command-r7b-12-2024"
        EXPECTED_COMPLETIONS = [
            " the mountains, the lakes, the rivers, the waterfalls, the waterfalls, the waterfalls, the waterfalls",
            ", green, yellow, orange, purple, pink, brown, black, white, grey, silver",
        ]

        input_text = [
            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
            "A list of colors: red, blue",  # This will almost all be padding tokens
        ]
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
        ).to(torch_device)

        # Make sure prefill is larger than sliding window
        input_size = inputs.input_ids.shape[-1]
        self.assertTrue(input_size > model.config.sliding_window)

        out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
        output_text = tokenizer.batch_decode(out)

        self.assertEqual(output_text, EXPECTED_COMPLETIONS)
@@ -394,3 +394,36 @@ class Gemma2IntegrationTest(unittest.TestCase):
        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)

        self.assertEqual(output_text, EXPECTED_TEXTS)

    @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
    @require_read_token
    def test_generation_beyond_sliding_window(self, attn_implementation: str):
        """Test that we can correctly generate beyond the sliding window. This is non trivial as
        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
        Outputs for every attention functions should be coherent and identical.
        """
        model_id = "google/gemma-2-2b"
        EXPECTED_COMPLETIONS = [
            " the people, the food, the culture, the history, the music, the art, the architecture",
            ", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
        ]

        input_text = [
            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
            "A list of colors: red, blue",  # This will almost all be padding tokens
        ]
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
        ).to(torch_device)

        # Make sure prefill is larger than sliding window
        input_size = inputs.input_ids.shape[-1]
        self.assertTrue(input_size > model.config.sliding_window)

        out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
        output_text = tokenizer.batch_decode(out)

        self.assertEqual(output_text, EXPECTED_COMPLETIONS)
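The two new integration tests can be run locally with something along these lines; the file paths follow the usual transformers test layout and the RUN_SLOW gating is an assumption about how the integration suites are marked:

import os
import pytest

os.environ["RUN_SLOW"] = "1"  # integration tests are normally skipped without this
pytest.main([
    "tests/models/cohere2/test_modeling_cohere2.py",
    "tests/models/gemma2/test_modeling_gemma2.py",
    "-k", "test_generation_beyond_sliding_window",
])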