diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 50a6dc1fc81..c4ef631e47b 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -730,7 +730,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
         if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
             chunked_attention_mask = chunked_attention_mask.bool()
-            causal_mask = causal_mask.bool()
+            causal_mask = causal_mask != torch.finfo(dtype).min
             if AttentionMaskConverter._ignore_causal_mask_sdpa(
                 attention_mask,
                 inputs_embeds=input_tensor,
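
Below is a minimal standalone sketch, not part of the patch, illustrating why the comparison against torch.finfo(dtype).min is used instead of .bool(). The tensor shape and values are made up for illustration: in an additive float mask, 0.0 marks positions that may be attended and torch.finfo(dtype).min marks blocked positions, so .bool() flags exactly the blocked entries as True, inverting the convention SDPA expects from a boolean mask (True = attend).

    import torch

    dtype = torch.float32
    min_value = torch.finfo(dtype).min

    # Toy 1x1x2x2 additive causal mask: query 0 may only attend to key 0.
    additive_mask = torch.tensor([[[[0.0, min_value], [0.0, 0.0]]]], dtype=dtype)

    wrong = additive_mask.bool()                     # True where the value is nonzero, i.e. where attention is blocked
    fixed = additive_mask != torch.finfo(dtype).min  # True where attention is allowed

    print(wrong[0, 0])  # tensor([[False,  True], [False, False]])
    print(fixed[0, 0])  # tensor([[ True, False], [ True,  True]])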