diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py
index 731844baf90..5bd2ff1bdbc 100644
--- a/tests/models/aria/test_modeling_aria.py
+++ b/tests/models/aria/test_modeling_aria.py
@@ -23,6 +23,7 @@ from transformers import (
     AriaForConditionalGeneration,
     AriaModel,
     AriaTextConfig,
     AutoProcessor,
     AutoTokenizer,
+    BitsAndBytesConfig,
     is_torch_available,
@@ -265,22 +266,32 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     def test_small_model_integration_test(self):
         # Let's make sure we test the preprocessing to replace what is used
-        model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
+        model = AriaForConditionalGeneration.from_pretrained(
+            "rhymes-ai/Aria",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]),
+        )
 
-        prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
+        prompt = "<|img|>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
         raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
-        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device, model.dtype)
 
-        EXPECTED_INPUT_IDS = torch.tensor(
-            [[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]],
-        )  # fmt: skip
+        non_img_tokens = [
+            109, 3905, 2000, 93415, 4551, 1162, 901, 3894, 970, 2478, 1017, 19312, 2388, 1596, 1809, 970, 5449, 1235,
+            3333, 93483, 109, 61081, 11984, 14800, 93415
+        ]  # fmt: skip
+        EXPECTED_INPUT_IDS = torch.tensor([[9] * 256 + non_img_tokens]).to(inputs["input_ids"].device)
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
 
         output = model.generate(**inputs, max_new_tokens=20)
-        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,"  # fmt: skip
-        decoded_output = self.processor.decode(output[0], skip_special_tokens=True)
-        self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
+        decoded_output = self.processor.decode(output[0], skip_special_tokens=True)
+
+        expected_output = Expectations(
+            {
+                ("cuda", None): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,",
+                ("rocm", (9, 5)): "\n USER: What are the things I should be cautious about when I visit this place?\n ASSISTANT: When you visit this place, you should be cautious about the following things:\n\n- The",
+            }
+        ).get_expectation()
+        self.assertEqual(decoded_output, expected_output)
 
     @slow
     @require_torch_large_accelerator
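
Note: the quantization change above replaces the bare `load_in_4bit=True` kwarg, which is deprecated in favor of an explicit `BitsAndBytesConfig`. A minimal standalone sketch of the same loading pattern, assuming a CUDA-capable machine with `bitsandbytes` installed:

```python
from transformers import AriaForConditionalGeneration, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # Keep selected modules unquantized; the test above skips
    # "multihead_attn", presumably because 4-bit weights there
    # degrade the generation used in the assertions.
    llm_int8_skip_modules=["multihead_attn"],
)

model = AriaForConditionalGeneration.from_pretrained(
    "rhymes-ai/Aria",
    quantization_config=quantization_config,
)
```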
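The `Expectations(...).get_expectation()` call lets one test carry different expected strings per accelerator (here plain CUDA vs. ROCm 9.5), instead of a single hard-coded `EXPECTED_DECODED_TEXT`. The real helper comes from the library's testing utilities; the sketch below is only an illustration of the lookup idea (exact key-matching rules are assumptions, not the actual implementation):

```python
import torch


def get_expectation(expectations: dict) -> str:
    """Pick the expected output registered for the current accelerator."""
    if torch.cuda.is_available():
        # ROCm builds of PyTorch expose a non-None `torch.version.hip`.
        device_type = "rocm" if torch.version.hip is not None else "cuda"
        version = torch.cuda.get_device_capability()
    else:
        device_type, version = "cpu", None

    # Prefer an exact (device, version) match, then a versionless default.
    for key in ((device_type, version), (device_type, None)):
        if key in expectations:
            return expectations[key]
    raise KeyError(f"no expectation registered for {device_type}")
```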