Patch GPTNeoX to use adequate FA2 if position_ids is provided (#35318)
parent 5e7aedebeb
commit 2bb60982ac
@@ -148,6 +148,7 @@ def flash_attention_forward(
     norm_factor,
     attention_dropout,
     training,
+    position_ids=None,
     target_dtype=None,
     **_kwargs,
 ):
@@ -173,6 +174,7 @@ def flash_attention_forward(
         attention_mask,
         query_length,
         dropout=attention_dropout,
+        position_ids=position_ids,
         softmax_scale=norm_factor,
         is_causal=True,
         use_top_left_mask=flash_attn_uses_top_left_mask,
@@ -353,6 +355,7 @@ class GPTNeoXAttention(nn.Module):
             key,
             value,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             norm_factor=self.norm_factor,
             attention_dropout=self.config.attention_dropout,
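For context: the point of threading position_ids through to the FA2 call is sequence packing. When several documents are concatenated into one row, each restarting its positions at 0, the flash-attention helper can recover the document boundaries from position_ids and dispatch to the variable-length kernel instead of letting tokens attend across document boundaries. Below is a minimal, illustrative sketch of that boundary detection; cu_seqlens_from_position_ids is a hypothetical name for this page, it assumes a fully packed batch with no padding, and the real logic in transformers' flash-attention utilities handles more cases.

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: recover packed-sequence boundaries from position_ids.

    position_ids: (batch, seq_len) integers; a value of 0 marks the start of a
    new packed sub-sequence. Returns cumulative sequence lengths (cu_seqlens)
    in the int32 layout that flash-attn's varlen kernels expect.
    """
    flat = position_ids.flatten()
    # Every position that resets to 0 starts a new sub-sequence.
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    end = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, end]).to(torch.int32)

# Two documents of lengths 3 and 4 packed into one row:
pos = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
print(cu_seqlens_from_position_ids(pos))  # tensor([0, 3, 7], dtype=torch.int32)

Without position_ids, the FA2 path has to treat each row as a single causal sequence, so tokens from one packed document could attend to an earlier one; the patch above simply passes position_ids through GPTNeoX so the shared helper can take the varlen path when packing is detected.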