diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 92804930872..08740173009 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -782,7 +782,7 @@ class Gemma3MultiModalProjector(nn.Module):
         return projected_vision_outputs.type_as(vision_outputs)


-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None

     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block

     return inner_mask

@@ -945,7 +950,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
         if token_type_ids is not None and inputs_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device)
+                token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
             )

         # Create the masks
@@ -1211,7 +1216,9 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             # Add the token type ids mask for generate as well
             if token_type_ids is not None and input_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                    token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+                )

             return create_masks_for_generate(**mask_kwargs)

diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index 054a75630b1..d679d30c8b9 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -722,7 +722,7 @@ class Gemma3MultiModalProjector(nn.Module):
         return projected_vision_outputs.type_as(vision_outputs)


-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -732,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None

     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block

     return inner_mask

@@ -836,7 +841,7 @@ class Gemma3Model(PaliGemmaModel):
         if token_type_ids is not None and inputs_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device)
+                token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
             )

         # Create the masks
@@ -1055,7 +1060,9 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
             # Add the token type ids mask for generate as well
             if token_type_ids is not None and input_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                    token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+                )

             return create_masks_for_generate(**mask_kwargs)
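
For review context, the sketch below shows what the new mask function produces once OR-ed with a causal mask, which is how `or_mask_function` is composed in the mask-creation step above. It is a minimal standalone illustration, not the transformers mask pipeline: the toy `token_type_ids` (one 4-token image) and `tokens_per_image = 4` are assumptions chosen for demonstration, and the dense boolean composition stands in for `create_masks_for_generate`. The distance check is what keeps tokens of two different images (both carrying token type 1) from unmasking each other.

import torch

# Copy of the patched helper from this diff, with type hints elided.
def token_type_ids_mask_function(token_type_ids, tokens_per_image):
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Queries and keys within one image's length of each other
        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
        # Both positions must actually be image tokens (token type 1)
        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
        return is_image_block & same_image_block

    return inner_mask

# Toy sequence (assumption): 2 text tokens, a 4-token image, 2 text tokens.
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
tokens_per_image = 4
seq_len = token_type_ids.shape[1]

# Broadcast query positions over rows and key/value positions over columns
q_idx = torch.arange(seq_len).view(-1, 1)
kv_idx = torch.arange(seq_len).view(1, -1)

causal = q_idx >= kv_idx  # standard causal mask
image_bidirectional = token_type_ids_mask_function(token_type_ids, tokens_per_image)(0, 0, q_idx, kv_idx)
allowed = causal | image_bidirectional  # the `or` composition

print(allowed.int())
# Rows 2-5 (image tokens) attend to all four image positions, including later
# ones, while text rows stay strictly causal. With the old one-liner,
# `token_type_ids[batch_idx, kv_idx] == 1` unmasked every image key for every
# query, so even text tokens before the image could attend to it.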