From e3d832ff87c6ec997125deaa4f1b239db8f9e613 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 15 Feb 2023 10:59:42 +0100 Subject: [PATCH] Fix Blip-2 CI again (#21637) * fix blip-2 ci * fix blip-2 ci --------- Co-authored-by: ydshieh --- tests/models/blip_2/test_modeling_blip_2.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 93d63104c53..c888eb08014 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -799,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase): def test_inference_opt_batched_beam_search(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") - model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device) + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16 + ).to(torch_device) # prepare image image = prepare_img() - inputs = processor(images=[image, image], return_tensors="pt").to(torch_device) + inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16) predictions = model.generate(**inputs, num_beams=2) @@ -844,14 +846,16 @@ class Blip2ModelIntegrationTest(unittest.TestCase): def test_inference_t5_batched_beam_search(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl") - model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device) + model = Blip2ForConditionalGeneration.from_pretrained( + "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16 + ).to(torch_device) # prepare image image = prepare_img() - inputs = processor(images=[image, image], return_tensors="pt").to(torch_device) + inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16) predictions = model.generate(**inputs, num_beams=2) # Test output (in this case, slightly different from greedy search) - self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1]) - self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1]) + self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) + self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])