diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index e8303a79848..7839f4f56af 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -348,8 +348,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
             final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
         else:
             causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
+            # invert causal mask
+            causal_mask = torch.where(causal_mask == 0, min_dtype, 0)
             causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
             final_labels = None
+
         return final_embedding, causal_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index 10fd48060a9..935ceaf72d7 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -454,3 +454,48 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
 
         # Make sure that `generate` works
         _ = model.generate(**inputs, max_new_tokens=20)
+
+    @slow
+    @require_torch
+    @require_read_token
+    def test_paligemma_finetuning_with_suffixes_bf16(self):
+        # this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
+        model_id = "google/paligemma-3b-pt-224"
+        model = PaliGemmaForConditionalGeneration.from_pretrained(
+            model_id, revision="bfloat16", torch_dtype=torch.bfloat16
+        ).to(torch_device)
+        # The first batch is longer in terms of text, the second will be padded.
+        prompts = [
+            "answer en Where is the cow standing?",
+            "",
+        ]
+
+        suffixes = ["beach", "cow standing on the beach"]
+        image1 = Image.open(
+            requests.get(
+                "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
+                stream=True,
+            ).raw
+        )
+        image2 = image1
+
+        inputs = (
+            self.processor(text=prompts, suffix=suffixes, images=[image1, image2], return_tensors="pt", padding=True)
+            .to(torch.bfloat16)
+            .to(torch_device)
+        )
+
+        expected_labels = torch.tensor(
+            [266 * [-100] + [54901, 1], 262 * [-100] + [14706, 9980, 611, 573, 8318, 1]]
+        ).to(torch_device)
+
+        assert torch.equal(inputs["labels"], expected_labels)
+
+        expected_token_type_ids = torch.tensor([266 * [0] + 2 * [1], 262 * [0] + 6 * [1]]).to(torch_device)
+
+        assert torch.equal(inputs["token_type_ids"], expected_token_type_ids)
+
+        output = model(**inputs)
+
+        # check that loss does not error out
+        _ = output.loss
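
For context, a minimal standalone sketch of what the added mask-inversion line does (not part of the patch; the batch size, sequence length, and dtype below are illustrative assumptions): positions that are allowed to attend keep an additive bias of 0, while blocked positions receive the most negative representable value so they contribute nothing after softmax.

import torch

# Illustrative sketch only; shapes and values are assumed, not the actual PaliGemma code path.
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min  # large negative value used for masked positions

# attention_mask: (batch, seq_len), 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0]])

# Outer product gives a (batch, 1, seq, seq) block mask of which query/key pairs may attend.
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)

# The patched line: convert the 0/1 mask into an additive bias
# (0 where attention is allowed, min_dtype where it is blocked).
causal_mask = torch.where(causal_mask == 0, min_dtype, 0)
print(causal_mask.to(dtype))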