mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
parent
9e1017b479
commit
b0735dc0c1
@ -315,7 +315,8 @@ class PaliGemmaProcessor(ProcessorMixin):
|
||||
return_data = {**inputs, "pixel_values": pixel_values}
|
||||
|
||||
if return_token_type_ids:
|
||||
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
|
||||
labels = np.array(inputs["input_ids"])
|
||||
labels[np.array(inputs["token_type_ids"]) == 0] = -100
|
||||
return_data.update({"labels": labels})
|
||||
|
||||
if return_mm_token_type_ids:
|
||||
|
@ -62,6 +62,20 @@ class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 112)
|
||||
|
||||
@require_torch
|
||||
def test_call_with_suffix(self):
|
||||
input_str = "lower newer"
|
||||
suffix = "upper older longer string"
|
||||
image_input = self.prepare_image_inputs()
|
||||
processor = self.get_processor()
|
||||
inputs = processor(text=input_str, images=image_input, suffix=suffix)
|
||||
self.assertTrue("labels" in inputs)
|
||||
self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, suffix=suffix, return_tensors="pt")
|
||||
self.assertTrue("labels" in inputs)
|
||||
self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
|
||||
|
||||
def test_text_with_image_tokens(self):
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
Loading…
Reference in New Issue
Block a user