Fix PaliGemma Pad Token Masking During Training #35855 (#35859)

* change order of unmasking of tokens

* library import

* class setup

* test function

* refactor

* add commit message

* test modified

* explicit initialisation of weights + made model smaller

* removed separate testing file

* fixup

* fixup core

* test attention mask with token types

* tests fixup

* removed PaliGemmaAttentionMaskTest class

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
Sambhav Dixit authored 2025-02-13 14:41:44 +05:30, committed by GitHub
parent 1614d196e8
commit 950cfb0b4f
2 changed files with 50 additions and 5 deletions


@@ -383,16 +383,20 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
+
+            # First unmask prefix tokens during training
+            if is_training:
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+                )
+
+            # Then apply padding mask (will mask pad tokens)
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
             )
-            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
-            if is_training:
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
-                )
+
         return causal_mask

     def get_image_features(self, pixel_values: torch.FloatTensor):
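
Why the reordering above matters: if the padding mask is applied first and the prefix unmasking second, any padding position that also carries token_type_id == 0 gets re-opened by the prefix unmasking. Below is a minimal, self-contained sketch of that interaction on toy tensors; it is not the model code, the helper names and values are illustrative, and it assumes (as in the failing case this commit fixes) that pad positions share token type 0 with the image/prefix tokens.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # last key position is padding
token_type_ids = torch.tensor([[0, 0, 1, 0]])   # 0 = image/prefix (and pad), 1 = suffix
min_dtype = torch.finfo(torch.float32).min

def build_mask(unmask_prefix_first: bool) -> torch.Tensor:
    seq_len = attention_mask.shape[-1]
    # Standard causal mask: 0 where attention is allowed, min_dtype where blocked.
    causal = torch.triu(torch.full((1, 1, seq_len, seq_len), min_dtype), diagonal=1)

    def unmask_prefix(mask):
        # Full (non-causal) attention over prefix tokens during training.
        return mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)

    def apply_padding(mask):
        # Block every key position whose attention_mask entry is 0.
        padding = (mask + attention_mask[:, None, None, :]) == 0
        return mask.masked_fill(padding, min_dtype)

    if unmask_prefix_first:                          # fixed order
        return apply_padding(unmask_prefix(causal))
    return unmask_prefix(apply_padding(causal))      # old (buggy) order

old_order = build_mask(unmask_prefix_first=False)
new_order = build_mask(unmask_prefix_first=True)
print(old_order[0, 0, :, 3])  # all zeros: the pad key is visible again (the bug)
print(new_order[0, 0, :, 3])  # all min_dtype: the pad key stays blocked

Running the sketch prints an all-zero column for the old order (the pad key is attendable) and an all-min_dtype column for the fixed order (the pad key stays masked), which is exactly what the new test below asserts on the real attention weights.
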


@@ -351,6 +351,47 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin
     def test_generate_compile_model_forward(self):
         pass

+    def test_attention_mask_with_token_types(self):
+        """Test that attention masking works correctly both with and without token type IDs."""
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            # Case 1: With token_type_ids
+            outputs_with_types = model(
+                **inputs_dict,
+                output_attentions=True,
+            )
+
+            # Case 2: Without token_type_ids
+            inputs_no_types = {k: v for k, v in inputs_dict.items() if k != "token_type_ids"}
+            outputs_no_types = model(
+                **inputs_no_types,
+                output_attentions=True,
+            )
+
+            attention_outputs_with_types = outputs_with_types.attentions
+            attention_outputs_no_types = outputs_no_types.attentions
+
+            # Verify pad tokens remain masked in both cases
+            attention_mask = inputs_dict["attention_mask"]
+            pad_positions = attention_mask == 0
+
+            for layer_attentions in [attention_outputs_with_types, attention_outputs_no_types]:
+                for layer_attn in layer_attentions:
+                    # Check if pad tokens are properly masked
+                    for batch_idx in range(layer_attn.shape[0]):
+                        for seq_idx in range(layer_attn.shape[-1]):
+                            if pad_positions[batch_idx, seq_idx]:
+                                # Verify attention weights for pad tokens are zero
+                                self.assertTrue(
+                                    torch.all(layer_attn[batch_idx, :, :, seq_idx] == 0),
+                                    f"Found non-zero attention weights for padding token at batch {batch_idx}, sequence position {seq_idx}",
+                                )
+
     @slow
     @require_torch
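
The per-position loops in the new test keep the intent explicit; an equivalent vectorized check is sketched below as a suggestion only. It is not part of this commit, and the helper name assert_pad_keys_masked is hypothetical.

import torch

def assert_pad_keys_masked(attentions, attention_mask):
    """Check that every padding key position receives zero attention weight.

    `attentions` is an iterable of per-layer tensors shaped
    (batch, num_heads, query_len, key_len); `attention_mask` is (batch, key_len)
    with 0 marking padding.
    """
    pad_keys = (attention_mask == 0)[:, None, None, :]  # broadcast over heads and queries
    for layer_attn in attentions:
        selected = layer_attn.masked_select(pad_keys.expand_as(layer_attn))
        assert torch.all(selected == 0), "Found non-zero attention weight on a padding key position"

The broadcasted boolean index selects every (head, query) weight aimed at a padding key in one call, so a failure reports in a single assertion rather than per position.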