From 342e3f9f2067d45c27c30fbd4d748d233bca3adc Mon Sep 17 00:00:00 2001
From: Chaehong Jeong
Date: Mon, 12 Aug 2024 18:18:15 +0900
Subject: [PATCH] Fix sliding window attention used in Gemma2FlashAttention2
 (#32522)

* fix sliding window attention (flash2) in gemma2 model

* [run-slow] gemma

* fix slicing attention_mask for flash_attn2

* fix slicing attention_mask when flash_attn is used

* add missing comment

* slice the last seq_len tokens in the key, value states

* revert code of slicing key, value states

---
 src/transformers/models/gemma2/modeling_gemma2.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index e23bb876c3f..ee5af616ec2 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -427,6 +427,7 @@ class Gemma2FlashAttention2(Gemma2Attention):
             dropout=dropout_rate,
             softmax_scale=self.scaling,
             is_causal=self.is_causal,
+            sliding_window=self.sliding_window,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
         )
@@ -567,7 +568,8 @@ class Gemma2DecoderLayer(nn.Module):
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
             # Flash-attn is a 2D tensor
             if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask[:, -self.sliding_window :]
+                if past_key_value is not None:  # when decoding
+                    attention_mask = attention_mask[:, -self.sliding_window :]
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
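
A minimal, hypothetical sketch of the behaviour the second hunk introduces: with flash_attention_2 the attention_mask is a 2D padding mask, and it should only be sliced to the last sliding_window positions once a KV cache exists (decoding), while the full mask must be kept during prefill. The function name and window size below are illustrative and not taken from the transformers codebase.

import torch

SLIDING_WINDOW = 4  # illustrative window size, not Gemma2's actual config value


def slice_mask_for_flash_attn(attention_mask: torch.Tensor, has_kv_cache: bool) -> torch.Tensor:
    # attention_mask: (batch, kv_len) padding mask of 0s and 1s, as used by flash_attention_2.
    if has_kv_cache:
        # Decoding: keep only the last SLIDING_WINDOW key/value positions.
        return attention_mask[:, -SLIDING_WINDOW:]
    # Prefill: the mask still has to cover the whole prompt.
    return attention_mask


prompt_mask = torch.ones(1, 10, dtype=torch.long)  # 10-token prompt, no padding
print(slice_mask_for_flash_attn(prompt_mask, has_kv_cache=False).shape)  # torch.Size([1, 10])
print(slice_mask_for_flash_attn(prompt_mask, has_kv_cache=True).shape)   # torch.Size([1, 4])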