diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index 63c812180a9..fb3404d263d 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -420,8 +420,11 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): @require_vision @require_bitsandbytes def test_batched_generation(self): - model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True) - + # Skip multihead_attn for 4bit because MHA will read the original weight without dequantize. + # See https://github.com/huggingface/transformers/pull/37444#discussion_r2045852538. + model = AriaForConditionalGeneration.from_pretrained( + "rhymes-ai/Aria", load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"] + ) processor = AutoProcessor.from_pretrained("rhymes-ai/Aria") prompt1 = "\n\nUSER: What's the difference of two images?\nASSISTANT:" @@ -432,24 +435,49 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): image1 = Image.open(requests.get(url1, stream=True).raw) image2 = Image.open(requests.get(url2, stream=True).raw) - inputs = processor( - images=[image1, image2, image1, image2], - text=[prompt1, prompt2, prompt3], - return_tensors="pt", - padding=True, - ).to(torch_device) - - model = model.eval() - - EXPECTED_OUTPUT = [ - "\n \nUSER: What's the difference of two images?\nASSISTANT: The difference between the two images is that one shows a dog standing on a grassy field, while", - "\nUSER: Describe the image.\nASSISTANT: The image features a brown and white dog sitting on a sidewalk. The dog is holding a small", - "\nUSER: Describe the image.\nASSISTANT: The image features a lone llama standing on a grassy hill. The llama is the", + # Create inputs + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt1}, + {"type": "image"}, + {"type": "text", "text": prompt2}, + ], + }, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt3}, + ], + }, ] + prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages] + images = [[image1, image2], [image2]] + inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to( + device=model.device, dtype=model.dtype + ) + + EXPECTED_OUTPUT = { + "cpu": [ + "<|im_start|>user\n \n \n USER: What's the difference of two images?\n ASSISTANT: \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The first image features a cute, light-colored puppy sitting on a paved surface with", + "<|im_start|>user\n \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The image shows a young alpaca standing on a grassy hill. The alpaca has", + ], # cpu output + "cuda": [ + "<|im_start|>user\n \n \n USER: What's the difference of two images?\n ASSISTANT: \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The first image features a cute, light-colored puppy sitting on a paved surface with", + "<|im_start|>user\n \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The image shows a young alpaca standing on a patch of ground with some dry grass. The", + ], # cuda output + "xpu": [ + "<|im_start|>user\n \n \n USER: What's the difference of two images?\n ASSISTANT: \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The first image features a cute, light-colored puppy sitting on a paved surface with", + "<|im_start|>user\n \n USER: Describe the image.\n ASSISTANT:<|im_end|>\n <|im_start|>assistant\n The image shows a young alpaca standing on a grassy hill. The alpaca has", + ], # xpu output + } generate_ids = model.generate(**inputs, max_new_tokens=20) outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - self.assertEqual(outputs, EXPECTED_OUTPUT) + self.assertListEqual(outputs, EXPECTED_OUTPUT[model.device.type]) def test_tokenizer_integration(self): model_id = "rhymes-ai/Aria"