Fix test_small_model_integration_test

remi-or 2025-06-30 05:24:25 -05:00
parent 1ccdce6bc1
commit 6de10e8d8b

@@ -23,6 +23,7 @@ from transformers import (
     AriaForConditionalGeneration,
     AriaModel,
     AriaTextConfig,
+    BitsAndBytesConfig,
     AutoProcessor,
     AutoTokenizer,
     is_torch_available,
@@ -265,22 +266,32 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     def test_small_model_integration_test(self):
         # Let's make sure we test the preprocessing to replace what is used
-        model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
+        model = AriaForConditionalGeneration.from_pretrained(
+            "rhymes-ai/Aria",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]),
+        )
 
-        prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
+        prompt = "<|img|>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
         raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
-        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device, model.dtype)
 
-        EXPECTED_INPUT_IDS = torch.tensor(
-            [[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722, 315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456, 1633, 28804, 13, 4816, 8048, 12738, 28747]],
-        ) # fmt: skip
+        non_img_tokens = [
+            109, 3905, 2000, 93415, 4551, 1162, 901, 3894, 970, 2478, 1017, 19312, 2388, 1596, 1809, 970, 5449, 1235,
+            3333, 93483, 109, 61081, 11984, 14800, 93415
+        ] # fmt: skip
+        EXPECTED_INPUT_IDS = torch.tensor([[9] * 256 + non_img_tokens]).to(inputs["input_ids"].device)
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
 
         output = model.generate(**inputs, max_new_tokens=20)
 
-        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
         decoded_output = self.processor.decode(output[0], skip_special_tokens=True)
-        self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
+        expected_output = Expectations(
+            {
+                ("cuda", None): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,",
+                ("rocm", (9, 5)): "\n USER: What are the things I should be cautious about when I visit this place?\n ASSISTANT: When you visit this place, you should be cautious about the following things:\n\n- The"
+            }
+        ).get_expectation()
+        self.assertEqual(decoded_output, expected_output)
 
     @slow
     @require_torch_large_accelerator
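
Note on the two behavioral changes above: passing load_in_4bit directly to from_pretrained is deprecated in transformers in favor of an explicit BitsAndBytesConfig under quantization_config, and the single hard-coded expected string is replaced by a per-device table resolved through the Expectations helper from transformers.testing_utils. As a minimal sketch of the lookup idea only (not the actual transformers implementation), the helper can be thought of as matching a (device_type, version) key against the current runtime, with None acting as a wildcard version:

import torch

def pick_expectation(table):
    # Identify the accelerator type; torch.version.hip is non-None on ROCm builds.
    if torch.cuda.is_available():
        device = "rocm" if torch.version.hip is not None else "cuda"
    else:
        device = "cpu"
    # Prefer an entry pinned to a specific version of this device type
    # (the real helper presumably compares against the detected GPU revision,
    # e.g. the (9, 5) key used for ROCm above; that probe is elided here).
    for (dev, version), value in table.items():
        if dev == device and version is not None:
            return value
    # Otherwise fall back to the wildcard entry for this device type.
    for (dev, version), value in table.items():
        if dev == device and version is None:
            return value
    # Last resort: any entry at all.
    return next(iter(table.values()))

On a CUDA machine this resolves to the ("cuda", None) string; on a matching ROCm machine it resolves to the ("rocm", (9, 5)) string, which is why the test can pass on both backends despite slightly different generations.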