Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
Commit: b5ececb900 (parent: c4e71e8fff)
@@ -782,7 +782,7 @@ class Gemma3MultiModalProjector(nn.Module):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
 
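A minimal sketch (not part of the commit) of what the new `inner_mask` computes, using the same convention as the diff: `token_type_ids == 1` marks image tokens, 0 marks text. Two toy "images" of two tokens each are placed far enough apart that the distance check separates them; the numbers are illustrative, not Gemma3 defaults.

import torch

def token_type_ids_mask_function(token_type_ids, tokens_per_image):
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        # Same logic as the committed version above
        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
        return is_image_block & same_image_block

    return inner_mask

token_type_ids = torch.tensor([[0, 1, 1, 0, 0, 0, 0, 1, 1, 0]])
mask_fn = token_type_ids_mask_function(token_type_ids, tokens_per_image=2)

idx = torch.arange(token_type_ids.shape[1])
grid = mask_fn(0, 0, idx.view(-1, 1), idx.view(1, -1))
print(grid.int())
# Tokens of the same image unmask each other in both directions; tokens of
# different images do not, because their distance exceeds tokens_per_image.
# The old version returned True for every image kv token regardless of q_idx,
# so distinct images could attend to each other.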
@@ -945,7 +950,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
             if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                    token_type_ids.to(cache_position.device)
+                    token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
                 )
 
             # Create the masks
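To see why the comment insists the extra function "needs to be an `or`": the token-type mask is combined with the base causal mask, so text tokens keep ordinary causal attention while image tokens additionally see the "future" tokens of their own image block. A hedged, self-contained illustration of that combination (the real combination happens inside the library's mask-creation utilities, not in user code like this):

import torch

seq_len = 10
token_type_ids = torch.tensor([[0, 1, 1, 0, 0, 0, 0, 1, 1, 0]])
tokens_per_image = 2

q = torch.arange(seq_len).view(-1, 1)
kv = torch.arange(seq_len).view(1, -1)

causal = kv <= q  # standard lower-triangular causal mask
image_block = (
    (token_type_ids[0, q] == 1)
    & (token_type_ids[0, kv] == 1)
    & (torch.abs(kv - q) <= tokens_per_image)  # the new inner_mask
)
combined = causal | image_block  # the `or` combination
print(combined.int())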
@@ -1211,7 +1216,9 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
 
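The hunks below apply the identical change to the modular definition of Gemma3 (note the PaliGemmaModel and PaliGemmaForConditionalGeneration base classes in the hunk headers). The value now threaded through as `tokens_per_image` comes from the model config; a hedged sketch of reading it, assuming a default-constructed `Gemma3Config` (released checkpoints may carry a different value):

from transformers import Gemma3Config

config = Gemma3Config()
# Number of soft tokens each image is projected to by the multimodal projector
print(config.mm_tokens_per_image)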
@@ -722,7 +722,7 @@ class Gemma3MultiModalProjector(nn.Module):
         return projected_vision_outputs.type_as(vision_outputs)
 
 
-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -732,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
         return None
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If it's 1, we need to unmask it
-        return token_type_ids[batch_idx, kv_idx] == 1
+        # If the difference is less than image size, both are part of the same image block
+        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
+        # If it's 1 for both query and key/value, we are in an image block
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+
+        # This is bidirectional attention whenever we are dealing with image tokens
+        return is_image_block & same_image_block
 
     return inner_mask
 
@@ -836,7 +841,7 @@ class Gemma3Model(PaliGemmaModel):
             if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                 # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
                 mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                    token_type_ids.to(cache_position.device)
+                    token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
                 )
 
             # Create the masks
@@ -1055,7 +1060,9 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
-            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
+            mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+            )
 
         return create_masks_for_generate(**mask_kwargs)
 