Patch GPTNeoX to use adequate FA2 if position_ids is provided (#35318)

This commit is contained in:
Taha Yassine 2024-12-23 13:45:55 +01:00 committed by GitHub
parent 5e7aedebeb
commit 2bb60982ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -148,6 +148,7 @@ def flash_attention_forward(
     norm_factor,
     attention_dropout,
     training,
+    position_ids=None,
     target_dtype=None,
     **_kwargs,
 ):
@@ -173,6 +174,7 @@ def flash_attention_forward(
         attention_mask,
         query_length,
         dropout=attention_dropout,
+        position_ids=position_ids,
         softmax_scale=norm_factor,
         is_causal=True,
         use_top_left_mask=flash_attn_uses_top_left_mask,
@@ -353,6 +355,7 @@ class GPTNeoXAttention(nn.Module):
             key,
             value,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             norm_factor=self.norm_factor,
             attention_dropout=self.config.attention_dropout,