From 2ad152f84c1111a87adf39467aef3c6bdd51cd41 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Tue, 20 May 2025 05:47:59 -0700 Subject: [PATCH] In Llama4 fix wrongly inverted causal attention mask when using SDPA implementation (#38094) When preparing the causal attention mask at this point the mask comes in as a float tensor with min value as a masked value. It is not correct to convert it to bool and treat it as a bool mask as this inverts the mask. `torch.nn.functional.scaled_dot_product_attention` expects that a masked value is `False`. I suspect that the `sdpa` implementation variant may not have been thoroughly tested and that is why this error was not caught earlier. Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/llama4/modeling_llama4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 50a6dc1fc81..c4ef631e47b 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -730,7 +730,7 @@ class Llama4TextModel(Llama4PreTrainedModel): # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None: chunked_attention_mask = chunked_attention_mask.bool() - causal_mask = causal_mask.bool() + causal_mask = causal_mask != torch.finfo(dtype).min if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor,