mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix Blip-2 CI again (#21637)
* fix blip-2 ci * fix blip-2 ci --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
762dda44de
commit
e3d832ff87
@ -799,11 +799,13 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_inference_opt_batched_beam_search(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# prepare image
|
||||
image = prepare_img()
|
||||
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs, num_beams=2)
|
||||
|
||||
@ -844,14 +846,16 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_inference_t5_batched_beam_search(self):
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
|
||||
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
"Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
# prepare image
|
||||
image = prepare_img()
|
||||
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
|
||||
inputs = processor(images=[image, image], return_tensors="pt").to(torch_device, dtype=torch.float16)
|
||||
|
||||
predictions = model.generate(**inputs, num_beams=2)
|
||||
|
||||
# Test output (in this case, slightly different from greedy search)
|
||||
self.assertEqual(predictions[0].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
|
||||
self.assertEqual(predictions[1].tolist(), [0, 3, 9, 2335, 19, 3823, 30, 8, 2608, 28, 160, 1782, 1])
|
||||
self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
|
||||
self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
|
||||
|
Loading…
Reference in New Issue
Block a user