Fix Blip-2 CI again (#21637)

* fix blip-2 ci

* fix blip-2 ci

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-02-15 10:59:42 +01:00 committed by GitHub
parent 762dda44de
commit e3d832ff87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -799,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
def test_inference_opt_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to(torch_device)
# prepare image
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16)
predictions = model.generate(**inputs, num_beams=2)
@ -844,14 +846,16 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
def test_inference_t5_batched_beam_search(self):
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
).to(torch_device)
# prepare image
image = prepare_img()
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16)
predictions = model.generate(**inputs, num_beams=2)
# Test output (in this case, slightly different from greedy search)
self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])