Fix paligemma inverted mask (#31207)

* pass inverted causal mask

* add sanity check for paligemma finetuning

* [run-slow]paligemma
This commit is contained in:
Pablo Montalvo 2024-06-10 11:22:39 +02:00 committed by GitHub
parent 807483edba
commit 6b11f89c6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 0 deletions

View File

@ -348,8 +348,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
else:
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
# invert causal mask
causal_mask = torch.where(causal_mask == 0, min_dtype, 0)
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
final_labels = None
return final_embedding, causal_mask, final_labels, position_ids
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)

View File

@ -454,3 +454,48 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@slow
@require_torch
@require_read_token
def test_paligemma_finetuning_with_suffixes_bf16(self):
# this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
).to(torch_device)
# The first batch is longer in terms of text, the second will be padded.
prompts = [
"answer en Where is the cow standing?",
"",
]
suffixes = ["beach", "cow standing on the beach"]
image1 = Image.open(
requests.get(
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
stream=True,
).raw
)
image2 = image1
inputs = (
self.processor(text=prompts, suffix=suffixes, images=[image1, image2], return_tensors="pt", padding=True)
.to(torch.bfloat16)
.to(torch_device)
)
expected_labels = torch.tensor(
[266 * [-100] + [54901, 1], 262 * [-100] + [14706, 9980, 611, 573, 8318, 1]]
).to(torch_device)
assert torch.equal(inputs["labels"], expected_labels)
expected_token_type_ids = torch.tensor([266 * [0] + 2 * [1], 262 * [0] + 6 * [1]]).to(torch_device)
assert torch.equal(inputs["token_type_ids"], expected_token_type_ids)
output = model(**inputs)
# check that loss does not error out
_ = output.loss