From b0735dc0c13a07f317c40b05581e36d21b306368 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Tue, 27 May 2025 11:31:56 +0200
Subject: [PATCH] [paligemma] fix processor with suffix (#38365)

fix pg processor
---
 .../models/paligemma/processing_paligemma.py       |  3 ++-
 tests/models/paligemma/test_processor_paligemma.py | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py
index c18a698c237..a630c4720ed 100644
--- a/src/transformers/models/paligemma/processing_paligemma.py
+++ b/src/transformers/models/paligemma/processing_paligemma.py
@@ -315,7 +315,8 @@ class PaliGemmaProcessor(ProcessorMixin):
         return_data = {**inputs, "pixel_values": pixel_values}
 
         if return_token_type_ids:
-            labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
+            labels = np.array(inputs["input_ids"])
+            labels[np.array(inputs["token_type_ids"]) == 0] = -100
             return_data.update({"labels": labels})
 
         if return_mm_token_type_ids:
diff --git a/tests/models/paligemma/test_processor_paligemma.py b/tests/models/paligemma/test_processor_paligemma.py
index 8ccae458875..56e74928925 100644
--- a/tests/models/paligemma/test_processor_paligemma.py
+++ b/tests/models/paligemma/test_processor_paligemma.py
@@ -62,6 +62,20 @@ class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         )
         self.assertEqual(len(inputs["input_ids"][0]), 112)
 
+    @require_torch
+    def test_call_with_suffix(self):
+        input_str = "lower newer"
+        suffix = "upper older longer string"
+        image_input = self.prepare_image_inputs()
+        processor = self.get_processor()
+        inputs = processor(text=input_str, images=image_input, suffix=suffix)
+        self.assertTrue("labels" in inputs)
+        self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
+
+        inputs = processor(text=input_str, images=image_input, suffix=suffix, return_tensors="pt")
+        self.assertTrue("labels" in inputs)
+        self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
+
     def test_text_with_image_tokens(self):
         image_processor = self.get_component("image_processor")
         tokenizer = self.get_component("tokenizer")