Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-29 17:22:25 +06:00)
Fix sliding window attention used in Gemma2FlashAttention2 (#32522)
* fix sliding window attention (flash2) in gemma2 model
* [run-slow] gemma
* fix slicing attention_mask for flash_attn2
* fix slicing attention_mask when flash_attn is used
* add missing comment
* slice the last seq_len tokens in the key, value states
* revert code of slicing key, value states
parent 8f2b6d5e3d
commit 342e3f9f20
@@ -427,6 +427,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
             dropout=dropout_rate,
             softmax_scale=self.scaling,
             is_causal=self.is_causal,
+            sliding_window=self.sliding_window,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
         )
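The first hunk forwards the model's `sliding_window` to the FlashAttention-2 call, so each query only attends to the most recent `sliding_window` keys. Below is a minimal, standalone sketch of the masking behaviour such a window implies (toy code with a made-up helper name, not the transformers implementation):

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where attention is allowed: causal AND within the last `window` key positions.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    too_old = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=-window)
    return causal & ~too_old

# Row i may attend to keys max(0, i - 3) .. i, i.e. at most 4 positions.
print(sliding_window_causal_mask(seq_len=8, window=4).int())

Passing the window to the kernel lets FlashAttention skip out-of-window blocks internally instead of requiring a mask like this to be materialised.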
@@ -567,7 +568,8 @@ class Gemma2DecoderLayer(nn.Module):
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
             # Flash-attn is a 2D tensor
             if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask[:, -self.sliding_window :]
+                if past_key_value is not None:  # when decoding
+                    attention_mask = attention_mask[:, -self.sliding_window :]
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
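The second hunk stops slicing the 2D flash-attn mask unconditionally: during prefill the full padding mask is kept (the kernel now enforces the window via `sliding_window`), and only while decoding with a populated cache is the mask trimmed to the last `sliding_window` columns so it lines up with the cached keys/values. A standalone sketch of that decoding-time trim, using toy shapes rather than the real model tensors:

import torch

sliding_window = 4
batch_size, cache_len = 2, 10

# FlashAttention-2 receives a 2D padding mask of shape (batch, kv_len).
attention_mask = torch.ones(batch_size, cache_len, dtype=torch.long)

past_key_value_present = True  # stands in for `past_key_value is not None`
if past_key_value_present:  # when decoding
    # Keep only the positions matching the last `sliding_window` cached keys/values.
    attention_mask = attention_mask[:, -sliding_window:]

print(attention_mask.shape)  # torch.Size([2, 4])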