Fix image token mask in Gemma3 (#38295)

fix mask
Author: Cyril Vallez, 2025-05-27 11:15:52 +02:00 (committed by GitHub)
parent c4e71e8fff
commit b5ececb900
2 changed files with 24 additions and 10 deletions

src/transformers/models/gemma3/modeling_gemma3.py

@@ -782,7 +782,7 @@ class Gemma3MultiModalProjector(nn.Module):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
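To see what the fix changes, here is a small standalone sketch (not part of the commit; the sequence layout and the tokens_per_image value are invented for illustration) that evaluates the old and the new predicate over a full query/key grid:

import torch

tokens_per_image = 3
# Toy layout: text, a 3-token image (positions 1-3), a text span, a second image (positions 9-11).
token_type_ids = torch.tensor([[0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1]])

seq_len = token_type_ids.shape[1]
q_idx = torch.arange(seq_len).unsqueeze(1)   # queries as rows
kv_idx = torch.arange(seq_len).unsqueeze(0)  # keys/values as columns

# Old predicate: every image key is unmasked for every query, so (once OR-ed
# with the causal mask) any token could attend to any image token, even across images.
old_mask = (token_type_ids[0, kv_idx] == 1).expand(seq_len, seq_len)

# New predicate: bidirectional attention only between image tokens of the same block.
same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
is_image_block = (token_type_ids[0, q_idx] == 1) & (token_type_ids[0, kv_idx] == 1)
new_mask = is_image_block & same_image_block

print(old_mask.int())  # solid columns of 1s at every image position
print(new_mask.int())  # two separate square blocks on the diagonal

The old predicate unmasks every image column for every query; the new one confines the bidirectional region to each image's own block.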
@@ -945,7 +950,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
             if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                    token_type_ids.to(cache_position.device)
+                    token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
                 )
 
             # Create the masks
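For context on the `or` mentioned in the comment: the extra mask function is OR-ed with the base causal predicate, so it can only ever add visibility. A minimal sketch of that composition (simplified semantics only, not transformers' actual masking utilities):

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # Standard causal rule: a query sees keys at or before its own position.
    return kv_idx <= q_idx

def or_compose(base_fn, extra_fn):
    # A position is visible if either predicate unmasks it.
    def combined(batch_idx, head_idx, q_idx, kv_idx):
        return base_fn(batch_idx, head_idx, q_idx, kv_idx) | extra_fn(batch_idx, head_idx, q_idx, kv_idx)
    return combined

Because the extra predicate can only add visibility, the old version, which returned True for every image key, effectively exposed all image tokens to all queries; the new version only widens attention inside an image block.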
@@ -1211,7 +1216,9 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
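The new argument is read from the model config: `mm_tokens_per_image` is the fixed number of placeholder tokens each image occupies after the multi-modal projector. A quick way to check it (assuming a recent transformers version; 256 is, to my knowledge, the default for Gemma 3):

from transformers import Gemma3Config

config = Gemma3Config()
# Number of soft tokens each image is projected to by the multi-modal projector.
print(config.mm_tokens_per_image)  # 256 by default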

src/transformers/models/gemma3/modular_gemma3.py
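Note: this second file is the modular source for Gemma 3. transformers auto-generates modeling_gemma3.py from modular_gemma3.py, which is why the identical change appears twice; presumably the modeling file was regenerated with the repo's modular converter (utils/modular_model_converter.py).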

@@ -722,7 +722,7 @@ class Gemma3MultiModalProjector(nn.Module):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -732,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
@@ -836,7 +841,7 @@ class Gemma3Model(PaliGemmaModel):
             if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                    token_type_ids.to(cache_position.device)
+                    token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
                 )
 
             # Create the masks
@@ -1055,7 +1060,9 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
         # Add the token type ids mask for generate as well
        if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)