Fix sliding window attention used in Gemma2FlashAttention2 (#32522)

* fix sliding window attention (flash2) in gemma2 model

* [run-slow] gemma

* fix slicing attention_mask for flash_attn2

* fix slicing attention_mask when flash_attn is used

* add missing comment

* slice the last seq_len tokens in the key, value states

* revert slicing of key, value states
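
The bullets above amount to one behavioural rule: with flash_attention_2 the attention mask is a 2D padding mask over key/value positions, and it should only be cut down to its last sliding_window columns once decoding has started and the cache is capped at that length; during prefill the full-length mask must be kept so it still matches the key/value states. A minimal, runnable sketch of that rule with toy shapes (the helper name is hypothetical, not part of transformers):

import torch

def slice_flash_attn_mask(attention_mask, sliding_window, past_key_value):
    # Hypothetical helper mirroring the decoder-layer logic: the flash-attn mask is a
    # 2D (batch, kv_len) padding mask, so trim it to the trailing sliding_window
    # columns only once a cache exists, i.e. while decoding.
    if past_key_value is not None:
        return attention_mask[:, -sliding_window:]
    return attention_mask

mask = torch.ones(2, 10, dtype=torch.long)  # 10-token prompt, no padding
print(slice_flash_attn_mask(mask, 4, past_key_value=None).shape)      # prefill  -> torch.Size([2, 10])
print(slice_flash_attn_mask(mask, 4, past_key_value=object()).shape)  # decoding -> torch.Size([2, 4])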
Chaehong Jeong 2024-08-12 18:18:15 +09:00 committed by GitHub
parent 8f2b6d5e3d
commit 342e3f9f20

@@ -427,6 +427,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
             dropout=dropout_rate,
             softmax_scale=self.scaling,
             is_causal=self.is_causal,
+            sliding_window=self.sliding_window,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
         )
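
The hunk above only threads the layer's sliding_window value through to the flash-attention call; the fused kernel applies the window itself. As a rough, hedged sketch of the idea (the helper below is illustrative and does not reproduce the exact _flash_attention_forward internals), optional kwargs like this are typically gated on what the installed flash-attn version supports:

def build_flash_kwargs(sliding_window, softcap, flash_version):
    # Illustrative only: collect optional flash-attn kwargs depending on kernel support
    # (local/windowed attention arrived around flash-attn 2.3, soft-capping in 2.6).
    kwargs = {}
    if sliding_window is not None:
        # flash-attn expresses local attention as a (left, right) window around each query.
        kwargs["window_size"] = (sliding_window, sliding_window)
    if softcap is not None and flash_version >= (2, 6, 0):
        kwargs["softcap"] = softcap
    return kwargs

print(build_flash_kwargs(sliding_window=4096, softcap=50.0, flash_version=(2, 6, 0)))
# {'window_size': (4096, 4096), 'softcap': 50.0}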
@@ -567,7 +568,8 @@ class Gemma2DecoderLayer(nn.Module):
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
             # Flash-attn is a 2D tensor
             if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask[:, -self.sliding_window :]
+                if past_key_value is not None:  # when decoding
+                    attention_mask = attention_mask[:, -self.sliding_window :]
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
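
The else branch opened by the last context lines covers the non-flash (eager/SDPA) path, where the sliding window is baked into a dense 4D additive mask rather than passed to a kernel. A standalone, simplified reconstruction of that pattern (not the verbatim continuation of the file):

import torch

def apply_sliding_window(attention_mask, sliding_window):
    # attention_mask: 4D additive causal mask of shape (batch, 1, q_len, kv_len),
    # 0.0 where attention is allowed and a large negative value where it is not.
    min_dtype = torch.finfo(attention_mask.dtype).min
    # True strictly below the -sliding_window diagonal, i.e. for keys lying more than
    # sliding_window - 1 positions behind their query row.
    sliding_window_mask = torch.tril(
        torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
    )
    # Mask those far-away keys out on top of the existing causal mask.
    return torch.where(sliding_window_mask, min_dtype, attention_mask)

q_len = kv_len = 6
min_val = torch.finfo(torch.float32).min
causal = torch.triu(torch.full((1, 1, q_len, kv_len), min_val), diagonal=1)
print(apply_sliding_window(causal, sliding_window=3)[0, 0])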