[mistral] Fix FA2 attention reshape for Mistral Nemo (#32065)

* [mistral] Fix FA2 attention reshape

* [run-slow] mistral
commit 22f888b3fa
parent cd48553fc8
Author: Joshua Lochner
Date:   2024-07-19 11:19:35 +02:00 (committed via GitHub)

@@ -387,7 +387,7 @@ class MistralFlashAttention2(MistralAttention):
             is_causal=self.is_causal,
         )
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
         attn_output = self.o_proj(attn_output)
         if not output_attentions:
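
The change matters because Mistral Nemo is a model where `num_heads * head_dim` does not equal `hidden_size`, so reshaping the Flash Attention output to `hidden_size` raises a shape error. The sketch below reproduces that failure mode with NumPy and hypothetical, illustrative dimensions (the variable values are assumptions chosen only to make the mismatch visible, not Nemo's real config):

```python
import numpy as np

# Hypothetical dimensions where the attention width differs from the
# model width, as in Mistral Nemo.
bsz, q_len = 2, 4
num_heads, head_dim = 8, 16   # attention output width: 8 * 16 = 128
hidden_size = 160             # model width != num_heads * head_dim

# FA2 returns attention output shaped (bsz, q_len, num_heads, head_dim).
attn_output = np.zeros((bsz, q_len, num_heads, head_dim))

# Old code: reshaping to hidden_size fails whenever
# num_heads * head_dim != hidden_size.
old_reshape_failed = False
try:
    attn_output.reshape(bsz, q_len, hidden_size)
except ValueError:
    old_reshape_failed = True  # 128 values per token cannot fill 160

# Fixed code: reshape to the true attention width always works;
# the o_proj linear layer then maps num_heads * head_dim -> hidden_size.
fixed = attn_output.reshape(bsz, q_len, num_heads * head_dim)

print(old_reshape_failed, fixed.shape)
```

In other words, the patch defers the width change to `o_proj`, which is the layer actually parameterized to project from the attention width to the model width.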