Fix paligemma detection inference (#31587)

* fix extended attention mask

* add slow test for detection instance

* [run-slow]paligemma
Pablo Montalvo 2024-06-26 19:17:09 +02:00 committed by GitHub
parent e71f2863d7
commit 492ee17ec3
2 changed files with 28 additions and 3 deletions

src/transformers/models/paligemma/modeling_paligemma.py

@@ -448,13 +448,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
# Get the target length
target_seqlen = cache_position[-1] + 1
extended_attention_mask = torch.ones(
- (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses PaliGemma+ Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
@@ -467,6 +465,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
attention_mask = attention_mask.to(inputs_embeds.dtype)
outputs = self.language_model(
attention_mask=attention_mask,
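
For intuition on the first hunk: the extension is sized from cache_position, and the fix widens it by exactly one column; per the commit message and the linked issue, the previous size left the concatenated mask short of the full cached context for detection prompts. A toy shape check with made-up lengths (illustration only, not part of the diff, and not the real image/prompt sizes):

import torch

# Made-up lengths purely for shape intuition; real values depend on the number
# of image tokens plus the prompt length.
attention_mask = torch.ones(1, 260, dtype=torch.long)  # mask for tokens seen so far
target_seqlen = 261 + 1  # stands in for cache_position[-1] + 1

extra_old = target_seqlen - attention_mask.shape[1]      # 2 columns (before the fix)
extra_new = target_seqlen - attention_mask.shape[1] + 1  # 3 columns (after the fix)

extended = torch.ones(attention_mask.shape[0], extra_new, dtype=attention_mask.dtype)
full_mask = torch.cat((attention_mask, extended), dim=1)
print(extra_old, extra_new, full_mask.shape)  # 2 3 torch.Size([1, 263])
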

tests/models/paligemma/test_modeling_paligemma.py

@@ -430,6 +430,32 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_torch
@require_read_token
def test_integration_detection_bug(self):
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
# negatively impacted segmentation generations.
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
).to(torch_device)
prompt = ("detect shoe",)
image = Image.open(
requests.get(
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/shoe.png",
stream=True,
).raw
)
inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)

@slow
@require_read_token
def test_paligemma_index_error_bug(self):
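
For reference, the EXPECTED_DECODED_TEXT in the new test encodes the detected box as four <locXXXX> tokens followed by the label; in PaliGemma these are usually read as y_min, x_min, y_max, x_max on a 0-1023 grid over the model's input image. A hypothetical parser (not part of this PR, assuming that 1024-bin convention and the 224x224 input resolution of this checkpoint):

import re

def parse_detection(decoded: str, width: int, height: int):
    # Hypothetical helper: convert "<loc0051><loc0309><loc0708><loc0646> shoe"
    # style output into pixel-space boxes, assuming 1024 coordinate bins and
    # the order y_min, x_min, y_max, x_max.
    boxes = []
    for m in re.finditer(r"((?:<loc\d{4}>){4})\s*([^<]+)", decoded):
        y1, x1, y2, x2 = (int(v) for v in re.findall(r"<loc(\d{4})>", m.group(1)))
        boxes.append({
            "label": m.group(2).strip(),
            "box": (x1 / 1024 * width, y1 / 1024 * height, x2 / 1024 * width, y2 / 1024 * height),
        })
    return boxes

print(parse_detection("detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe", 224, 224))
# [{'label': 'shoe', 'box': (67.59375, 11.15625, 141.3125, 154.875)}]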