mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
[mistral] Fix FA2 attention reshape for Mistral Nemo (#32065)
* [mistral] Fix FA2 attention reshape * [run-slow] mistral
This commit is contained in:
parent
cd48553fc8
commit
22f888b3fa
@ -387,7 +387,7 @@ class MistralFlashAttention2(MistralAttention):
|
||||
is_causal=self.is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
|
Loading…
Reference in New Issue
Block a user