Mirror of https://github.com/huggingface/transformers.git
Patch GPTNeoX to use adequate FA2 if position_ids is provided (#35318)
parent 5e7aedebeb
commit 2bb60982ac
@@ -148,6 +148,7 @@ def flash_attention_forward(
     norm_factor,
     attention_dropout,
     training,
+    position_ids=None,
     target_dtype=None,
     **_kwargs,
 ):
@@ -173,6 +174,7 @@ def flash_attention_forward(
         attention_mask,
         query_length,
         dropout=attention_dropout,
+        position_ids=position_ids,
         softmax_scale=norm_factor,
         is_causal=True,
         use_top_left_mask=flash_attn_uses_top_left_mask,
@@ -353,6 +355,7 @@ class GPTNeoXAttention(nn.Module):
             key,
             value,
             attention_mask=attention_mask,
+            position_ids=position_ids,
             head_mask=head_mask,
             norm_factor=self.norm_factor,
             attention_dropout=self.config.attention_dropout,
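Taken together, the three hunks thread position_ids from GPTNeoXAttention through flash_attention_forward into the flash-attention helper call. Below is a hedged usage sketch of the patched path; the checkpoint name, tensor values, dtype, and device are illustrative assumptions rather than part of the commit, and it presumes flash-attn is installed on a CUDA GPU.

# Hedged usage sketch, not part of the commit: pass position_ids to a
# GPTNeoX checkpoint running with the FA2 attention implementation.
import torch
from transformers import GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m",                  # any GPTNeoX-based checkpoint (assumption)
    attn_implementation="flash_attention_2",  # select the FA2 code path
    torch_dtype=torch.bfloat16,
).to("cuda")

# Two sequences of lengths 3 and 2 packed into one row; position_ids
# restarting at 0 marks the boundary between the packed sequences.
input_ids = torch.tensor([[11, 12, 13, 21, 22]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1]], device="cuda")

out = model(input_ids=input_ids, position_ids=position_ids)
print(out.logits.shape)  # (1, 5, vocab_size)

The mid-row restart of position_ids is what lets the flash-attention helper tell packed sequences apart, which appears to be why the commit forwards position_ids instead of dropping it.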