mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fix paligemma detection inference (#31587)
* fix extended attention mask * add slow test for detection instance * [run-slow]paligemma
This commit is contained in:
parent
e71f2863d7
commit
492ee17ec3
@ -448,13 +448,11 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
|
||||
# Get the target length
|
||||
target_seqlen = cache_position[-1] + 1
|
||||
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
||||
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses PaliGemma+ Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
@ -467,6 +465,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
|
||||
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
||||
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
|
||||
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
|
@ -430,6 +430,32 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
|
||||
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_read_token
|
||||
def test_integration_detection_bug(self):
|
||||
# this is a reproducer of https://github.com/huggingface/transformers/issues/31425 where not enough context
|
||||
# impacted negatively segmentation generations.
|
||||
model_id = "google/paligemma-3b-pt-224"
|
||||
model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
prompt = ("detect shoe",)
|
||||
|
||||
image = Image.open(
|
||||
requests.get(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/shoe.png",
|
||||
stream=True,
|
||||
).raw
|
||||
)
|
||||
|
||||
inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20)
|
||||
|
||||
EXPECTED_DECODED_TEXT = "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe" # fmt: skip
|
||||
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_paligemma_index_error_bug(self):
|
||||
|
Loading…
Reference in New Issue
Block a user