diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py
index eca69a58fb3..c416179131c 100644
--- a/src/transformers/models/pixtral/modeling_pixtral.py
+++ b/src/transformers/models/pixtral/modeling_pixtral.py
@@ -214,7 +214,6 @@ class PixtralAttention(nn.Module):
         # Since we use packing, if flash_attention_2 is selected we rely on position_ids
         if self.config._attn_implementation == "flash_attention_2":
            kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
-            attention_mask = None

         attn_output, attn_weights = attention_interface(
             self,
@@ -508,9 +507,13 @@ class PixtralVisionModel(PixtralPreTrainedModel):

         position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)

-        attention_mask = generate_block_attention_mask(
-            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
-        )
+        if self.config._attn_implementation == "flash_attention_2":
+            # We only rely on position_ids when using flash_attention_2
+            attention_mask = None
+        else:
+            attention_mask = generate_block_attention_mask(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
+            )

         return self.transformer(
             patch_embeds,
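
For reference, here is a minimal sketch of the kind of block-diagonal additive mask that `generate_block_attention_mask` produces for the eager/SDPA paths, which the flash_attention_2 branch now skips entirely in favor of `position_ids`. The function name, shapes, and constants below are illustrative assumptions, not the library's exact implementation.

```python
# Illustrative sketch (not the library's exact code): build a block-diagonal
# additive attention mask so that, when several images are packed into one
# sequence, patches from different images cannot attend to each other.
import torch


def block_attention_mask_sketch(patch_counts, patch_embeds):
    # patch_counts: number of patches contributed by each packed image
    # patch_embeds: (batch, seq_len, hidden) tensor the mask will be added to
    dtype, device = patch_embeds.dtype, patch_embeds.device
    seq_len = patch_embeds.shape[1]
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device)

    start = 0
    for count in patch_counts:
        # Zero out the block so patches of the same image can attend to each other
        mask[start : start + count, start : start + count] = 0
        start += count

    # Broadcast to (batch, 1, seq_len, seq_len), the shape eager/SDPA attention expects
    return mask[None, None, :, :].expand(patch_embeds.shape[0], 1, -1, -1)


# Example: two images contributing 4 and 2 patches to a packed sequence of 6
embeds = torch.zeros(1, 6, 8)
print(block_attention_mask_sketch([4, 2], embeds)[0, 0])
```

With flash_attention_2, the same per-image separation is conveyed through `position_ids` that restart at 0 for each image, so building (and materializing) this seq_len x seq_len mask is unnecessary.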