diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index c393c9cf88d..d9716272984 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -21,6 +21,7 @@ from parameterized import parameterized from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed from transformers.testing_utils import ( + Expectations, require_bitsandbytes, require_read_token, require_torch, @@ -417,7 +418,14 @@ class ChameleonIntegrationTest(unittest.TestCase): inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16) # greedy generation outputs - EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in'] # fmt: skip + EXPECTED_TEXT_COMPLETIONS = Expectations( + { + ("cuda", 7): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in'], + ("cuda", 8): ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot representing the position of the star Alpha Centauri. Alpha Centauri is the brightest star in the constellation Centaurus and is located'], + } + ) # fmt: skip + EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation() + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) text = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @@ -447,10 +455,20 @@ class ChameleonIntegrationTest(unittest.TestCase): ) # greedy generation outputs - EXPECTED_TEXT_COMPLETION = [ - 'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in', - 'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.' - ] # fmt: skip + EXPECTED_TEXT_COMPLETIONS = Expectations( + { + ("cuda", 7): [ + 'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and', + 'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.', + ], + ("cuda", 8): [ + 'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in', + 'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.', + ], + } + ) # fmt: skip + EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation() + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) text = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text)