fix patch_attention_mask incorrect setting which leads to the difference in the generated text if batch > 1 (#33499)

* fix patch_attention_mask incorrect setting which leads to the difference in the generated text if batch > 1

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* fix format

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* [run_slow] idefics2

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2024-09-18 05:24:42 +08:00 committed by GitHub
parent 6c051b4e1e
commit 454a0f2efd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 1 deletions

View File

@ -1388,7 +1388,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(

View File

@ -540,6 +540,41 @@ class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
self.assertEqual(generated_texts[0], expected_generated_text)
@slow
@require_bitsandbytes
def test_integration_test_4bit_batch2(self):
    """Batched generation (batch=2) must match per-sample generation.

    Regression test for the patch_attention_mask bug where padded
    batches produced text different from single-sample runs.
    """
    # 4-bit quantization keeps the 8B model small enough for the test runner.
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b-base",
        load_in_4bit=True,
    )
    from datasets import load_dataset

    dataset = load_dataset("nielsr/docvqa_1200_examples", split="test")

    # Batched pass over two document-VQA examples.
    batch_prompts = [f"<image>{dataset[40]['query']['en']}", f"<image>{dataset[41]['query']['en']}"]
    batch_images = [[dataset[40]["image"]], [dataset[41]["image"]]]
    batch_inputs = self.processor(text=batch_prompts, images=batch_images, padding=True, return_tensors="pt")
    batch_ids = model.generate(**batch_inputs, max_new_tokens=64)
    batched_generated_texts = self.processor.batch_decode(batch_ids, skip_special_tokens=True)

    # Single-sample passes for the same two examples.
    single_texts = []
    for idx in (40, 41):
        prompt = f"<image>{dataset[idx]['query']['en']}"
        image = dataset[idx]["image"]
        single_inputs = self.processor(text=prompt, images=image, padding=True, return_tensors="pt")
        single_ids = model.generate(**single_inputs, max_new_tokens=64)
        single_texts.append(self.processor.batch_decode(single_ids, skip_special_tokens=True))

    # Each batched output must equal its standalone counterpart.
    self.assertEqual(batched_generated_texts[0], single_texts[0][0])
    self.assertEqual(batched_generated_texts[1], single_texts[1][0])
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes