[paligemma] fix processor with suffix (#38365)

fix pg processor
This commit is contained in:
Raushan Turganbay 2025-05-27 11:31:56 +02:00 committed by GitHub
parent 9e1017b479
commit b0735dc0c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 1 deletions

View File

@ -315,7 +315,8 @@ class PaliGemmaProcessor(ProcessorMixin):
return_data = {**inputs, "pixel_values": pixel_values} return_data = {**inputs, "pixel_values": pixel_values}
if return_token_type_ids: if return_token_type_ids:
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100) labels = np.array(inputs["input_ids"])
labels[np.array(inputs["token_type_ids"]) == 0] = -100
return_data.update({"labels": labels}) return_data.update({"labels": labels})
if return_mm_token_type_ids: if return_mm_token_type_ids:

View File

@ -62,6 +62,20 @@ class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
) )
self.assertEqual(len(inputs["input_ids"][0]), 112) self.assertEqual(len(inputs["input_ids"][0]), 112)
@require_torch
def test_call_with_suffix(self):
    """Check that passing a suffix makes the processor emit aligned labels.

    The processor is called twice with the same text, image and suffix:
    once without ``return_tensors`` and once with ``return_tensors="pt"``.
    Each result must contain a ``"labels"`` entry whose first row has the
    same length as the first row of ``"input_ids"``.
    """
    prompt = "lower newer"
    target = "upper older longer string"
    images = self.prepare_image_inputs()
    processor = self.get_processor()

    # First pass uses the processor's default output type, second pass
    # requests PyTorch tensors.
    for extra_kwargs in ({}, {"return_tensors": "pt"}):
        outputs = processor(text=prompt, images=images, suffix=target, **extra_kwargs)
        self.assertIn("labels", outputs)
        label_row = outputs["labels"][0]
        input_row = outputs["input_ids"][0]
        self.assertEqual(len(label_row), len(input_row))
def test_text_with_image_tokens(self): def test_text_with_image_tokens(self):
image_processor = self.get_component("image_processor") image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer") tokenizer = self.get_component("tokenizer")