In Llama4 fix wrongly inverted causal attention mask when using SDPA implementation (#38094)

When preparing the causal attention mask, at this point the mask is a float tensor in which masked positions hold the dtype's minimum value.
Converting it to bool and treating it as a bool mask is incorrect, because it inverts the mask: the nonzero "masked" sentinel values become `True`, whereas
`torch.nn.functional.scaled_dot_product_attention` expects masked positions to be `False`.
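
For context (not part of this commit), a minimal standalone sketch of the inversion: with a float mask that uses the dtype's minimum value as the "masked" sentinel, `.bool()` marks exactly the masked positions as `True`, which SDPA reads as "attend here". Comparing against the sentinel instead yields the expected bool mask.

    import torch

    dtype = torch.float32
    min_val = torch.finfo(dtype).min

    # Additive-style float mask over 4 key positions: 0.0 = attend, dtype min = masked.
    float_mask = torch.tensor([0.0, 0.0, min_val, min_val], dtype=dtype)

    # Wrong: bool() maps any nonzero value to True, so the masked (min-valued)
    # positions become True ("attend") and the allowed positions become False.
    inverted = float_mask.bool()        # tensor([False, False,  True,  True])

    # Right: compare against the sentinel so allowed positions are True and
    # masked positions are False, matching SDPA's bool-mask convention.
    bool_mask = float_mask != min_val   # tensor([ True,  True, False, False])

    q = torch.randn(1, 1, 1, 8)
    k = torch.randn(1, 1, 4, 8)
    v = torch.randn(1, 1, 4, 8)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=bool_mask[None, None, None, :]
    )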

I suspect the `sdpa` implementation variant has not been thoroughly tested, which is why this error was not caught earlier.

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Author: Boian Petkantchin
Date: 2025-05-20 05:47:59 -07:00
Committed by: GitHub
Commit: 2ad152f84c
Parent: de70c8426e


@@ -730,7 +730,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
         if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
             chunked_attention_mask = chunked_attention_mask.bool()
-            causal_mask = causal_mask.bool()
+            causal_mask = causal_mask != torch.finfo(dtype).min
         if AttentionMaskConverter._ignore_causal_mask_sdpa(
             attention_mask,
             inputs_embeds=input_tensor,