From de4cf5a38e9678b9e465867a8a6b88ea727bea52 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 2 Jun 2025 22:46:35 +0200 Subject: [PATCH] Fix blip2 tests (#38510) * fix 1: not sure * fix 2: _supports_flex_attn = False * fix 3: embedding_output = self.layernorm(query_embeds.to(self.layernorm.weight.dtype)) * fix 4: query_embeds = query_embeds.to(self.layernorm.weight.dtype) * fix 5: text_embeds = text_embeds.to(dtype=torch.float16) * fix 5: question_embeds.to(dtype=torch.float16) * fix 6: text_embeds = text_embeds.to(dtype=self.itm_head.weight.dtype) * fix 7: image_embeds and question_embeds * fix 8: fix other 2 fp16 tests * fix 9: fix T5 OOM * fix 10: fix T5 OOM * fix 11: fix T5 * fix 11: fix T5 beam * fix 12: _supports_sdpa=False * fix 12: style and expect * revert * revert --------- Co-authored-by: ydshieh --- .../models/blip_2/modeling_blip_2.py | 7 +++ tests/models/blip_2/test_modeling_blip_2.py | 50 ++++++++++++++++--- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index ea591bf730d..5945f4f48ce 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1196,6 +1196,8 @@ class Blip2QFormerModel(Blip2PreTrainedModel): query_length if query_length is not None else query_embeds.shape[1] if query_embeds is not None else 0 ) + # `Blip2QFormerModel` is kept as fp32 + query_embeds = query_embeds.to(self.layernorm.weight.dtype) embedding_output = self.layernorm(query_embeds) embedding_output = self.dropout(embedding_output) @@ -1737,6 +1739,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel): ) pooled_output = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + pooled_output = pooled_output.to(dtype=self.text_projection.weight.dtype) text_embeds = self.text_projection(pooled_output) text_embeds = nn.functional.normalize(text_embeds, dim=-1) @@ -1837,6 +1840,7 @@ class Blip2VisionModelWithProjection(Blip2PreTrainedModel): ) embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + embeds = embeds.to(dtype=self.vision_projection.weight.dtype) image_embeds = self.vision_projection(embeds) image_embeds = nn.functional.normalize(image_embeds, dim=-1) @@ -2395,6 +2399,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): return_dict=return_dict, ) text_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + text_embeds = text_embeds.to(dtype=self.itm_head.weight.dtype) output = self.itm_head(text_embeds[:, : query_tokens.size(1), :]) logits_per_image = output.mean(dim=1) @@ -2408,6 +2413,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): return_dict=return_dict, ) image_embeds = query_outputs[0] if not return_dict else query_outputs.last_hidden_state + image_embeds = image_embeds.to(dtype=self.vision_projection.weight.dtype) query_embeds = self.embeddings( input_ids=input_ids, @@ -2419,6 +2425,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): return_dict=return_dict, ) question_embeds = text_outputs[0] if not return_dict else text_outputs.last_hidden_state + question_embeds = question_embeds.to(dtype=self.text_projection.weight.dtype) # normalized features image_embeds = nn.functional.normalize(self.vision_projection(image_embeds), dim=-1) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 89579ae833a..38b3714d7c8 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -24,6 +24,8 @@ from parameterized import parameterized from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig from transformers.testing_utils import ( + Expectations, + cleanup, require_torch, require_torch_accelerator, require_torch_fp16, @@ -1620,6 +1622,12 @@ def prepare_img(): @require_torch @slow class Blip2ModelIntegrationTest(unittest.TestCase): + def setUp(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + def test_inference_opt(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") model = Blip2ForConditionalGeneration.from_pretrained( @@ -1698,9 +1706,19 @@ class Blip2ModelIntegrationTest(unittest.TestCase): predictions = model.generate(**inputs) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + expectations = Expectations( + { + ("cuda", 7): [ + [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1], + "a woman is playing with her dog on the beach", + ] + } + ) + expected_outputs = expectations.get_expectation() + # Test output - self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) - self.assertEqual("woman playing with dog on the beach", generated_text) + self.assertEqual(predictions[0].tolist(), expected_outputs[0]) + self.assertEqual(expected_outputs[1], generated_text) # image and context prompt = "Question: which city is this? Answer:" @@ -1709,9 +1727,19 @@ class Blip2ModelIntegrationTest(unittest.TestCase): predictions = model.generate(**inputs) generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() + expectations = Expectations( + { + ("cuda", 7): [ + [0, 3, 7, 152, 2515, 11389, 3523, 1], + "san francisco", + ] + } + ) + expected_outputs = expectations.get_expectation() + # Test output - self.assertEqual(predictions[0].tolist(), [0, 3, 7, 152, 67, 839, 1]) - self.assertEqual(generated_text, "san diego") + self.assertEqual(predictions[0].tolist(), expected_outputs[0]) + self.assertEqual(generated_text, expected_outputs[1]) def test_inference_t5_batched_beam_search(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") @@ -1725,9 +1753,19 @@ class Blip2ModelIntegrationTest(unittest.TestCase): predictions = model.generate(**inputs, num_beams=2) + expectations = Expectations( + { + ("cuda", 7): [ + [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1], + [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1], + ] + } + ) + expected_predictions = expectations.get_expectation() + # Test output (in this case, slightly different from greedy search) - self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) - self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) + self.assertEqual(predictions[0].tolist(), expected_predictions[0]) + self.assertEqual(predictions[1].tolist(), expected_predictions[1]) @require_torch_multi_accelerator def test_inference_opt_multi_accelerator(self):