mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Fix paligemma inverted mask (#31207)
* pass inverted causal mask
* add sanity check for paligemma finetuning
* [run-slow]paligemma
This commit is contained in:
parent
807483edba
commit
6b11f89c6b
@ -348,8 +348,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
|
||||
else:
|
||||
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||
# invert causal mask
|
||||
causal_mask = torch.where(causal_mask == 0, min_dtype, 0)
|
||||
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
|
||||
final_labels = None
|
||||
|
||||
return final_embedding, causal_mask, final_labels, position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
||||
|
@ -454,3 +454,48 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# Make sure that `generate` works
|
||||
_ = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
@slow
@require_torch
@require_read_token
def test_paligemma_finetuning_with_suffixes_bf16(self):
    """Sanity-check a fine-tuning style forward pass (prompt + suffix) in bfloat16.

    The processor is called with both ``text`` and ``suffix`` so that it
    returns ``labels`` and ``token_type_ids``. The test pins the exact
    values of both tensors for a two-sample batch with prompts of different
    lengths, then runs a forward pass and checks that a loss is produced.
    """
    # this is a supplementary test to ensure paligemma fine-tuning that relies on token_type_ids is robust to future changes
    model_id = "google/paligemma-3b-pt-224"
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id, revision="bfloat16", torch_dtype=torch.bfloat16
    ).to(torch_device)

    # The first sample has the longer text prompt; the second one (empty prompt) will be padded.
    prompts = [
        "answer en Where is the cow standing?",
        "",
    ]
    suffixes = ["beach", "cow standing on the beach"]

    # An explicit timeout keeps a stalled download from hanging the test run forever.
    image1 = Image.open(
        requests.get(
            "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
            stream=True,
            timeout=30,
        ).raw
    )
    image2 = image1

    inputs = (
        self.processor(text=prompts, suffix=suffixes, images=[image1, image2], return_tensors="pt", padding=True)
        .to(torch.bfloat16)
        .to(torch_device)
    )

    # Every position except the trailing suffix tokens is expected to carry the ignore index -100.
    expected_labels = torch.tensor(
        [266 * [-100] + [54901, 1], 262 * [-100] + [14706, 9980, 611, 573, 8318, 1]]
    ).to(torch_device)
    # unittest assertions are used instead of bare `assert`, which is stripped under `python -O`.
    self.assertTrue(torch.equal(inputs["labels"], expected_labels))

    # token_type_ids are expected to be 1 exactly where the labels are not -100, and 0 elsewhere.
    expected_token_type_ids = torch.tensor([266 * [0] + 2 * [1], 262 * [0] + 6 * [1]]).to(torch_device)
    self.assertTrue(torch.equal(inputs["token_type_ids"], expected_token_type_ids))

    output = model(**inputs)

    # check that loss does not error out; labels were passed, so a loss is expected to be returned
    self.assertIsNotNone(output.loss)
|
||||
|
Loading…
Reference in New Issue
Block a user