Fix test_small_model_integration_test
commit 6de10e8d8b
parent 1ccdce6bc1
@@ -23,6 +23,7 @@ from transformers import (
     AriaForConditionalGeneration,
     AriaModel,
     AriaTextConfig,
+    BitsAndBytesConfig,
     AutoProcessor,
     AutoTokenizer,
     is_torch_available,
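The newly imported BitsAndBytesConfig backs the loading change in the hunk below: the bare load_in_4bit=True kwarg to from_pretrained is swapped for an explicit quantization config. A minimal standalone sketch of the same load, assuming a GPU machine with bitsandbytes installed (every name here is taken from the diff itself):

    from transformers import AriaForConditionalGeneration, BitsAndBytesConfig

    # Explicit 4-bit quantization; llm_int8_skip_modules keeps the listed
    # modules (the multihead attention here) out of quantization.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_skip_modules=["multihead_attn"],
    )
    model = AriaForConditionalGeneration.from_pretrained(
        "rhymes-ai/Aria",
        quantization_config=quant_config,
    )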
@@ -265,22 +266,32 @@ class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
     @require_bitsandbytes
     def test_small_model_integration_test(self):
         # Let's make sure we test the preprocessing to replace what is used
-        model = AriaForConditionalGeneration.from_pretrained("rhymes-ai/Aria", load_in_4bit=True)
+        model = AriaForConditionalGeneration.from_pretrained(
+            "rhymes-ai/Aria",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["multihead_attn"]),
+        )

-        prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
+        prompt = "<|img|>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
         raw_image = Image.open(requests.get(IMAGE_OF_VIEW_URL, stream=True).raw)
-        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device, model.dtype)

-        EXPECTED_INPUT_IDS = torch.tensor(
-            [[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722, 315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456, 1633, 28804, 13, 4816, 8048, 12738, 28747]],
-        )  # fmt: skip
+        non_img_tokens = [
+            109, 3905, 2000, 93415, 4551, 1162, 901, 3894, 970, 2478, 1017, 19312, 2388, 1596, 1809, 970, 5449, 1235,
+            3333, 93483, 109, 61081, 11984, 14800, 93415
+        ]  # fmt: skip
+        EXPECTED_INPUT_IDS = torch.tensor([[9] * 256 + non_img_tokens]).to(inputs["input_ids"].device)
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

         output = model.generate(**inputs, max_new_tokens=20)
-        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,"  # fmt: skip

         decoded_output = self.processor.decode(output[0], skip_special_tokens=True)
-        self.assertEqual(decoded_output, EXPECTED_DECODED_TEXT)
+        expected_output = Expectations(
+            {
+                ("cuda", None): "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,",
+                ("rocm", (9, 5)): "\n USER: What are the things I should be cautious about when I visit this place?\n ASSISTANT: When you visit this place, you should be cautious about the following things:\n\n- The"
+            }
+        ).get_expectation()
+        self.assertEqual(decoded_output, expected_output)

     @slow
     @require_torch_large_accelerator
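The rewritten input-ids expectation makes the preprocessing contract explicit: per the test, the processor expands the single <|img|> placeholder into 256 image-token ids (id 9) and appends the prompt's text tokens, so the expected tensor is constructed instead of hard-coded. A sketch of that construction, reusing names from the hunk above (token list truncated here for brevity):

    import torch

    non_img_tokens = [109, 3905, 2000]  # truncated; full list in the diff above
    # One <|img|> placeholder -> 256 image-token ids (9), then the text tokens.
    EXPECTED_INPUT_IDS = torch.tensor([[9] * 256 + non_img_tokens])
    assert EXPECTED_INPUT_IDS.shape == (1, 256 + len(non_img_tokens))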
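Expectations is the device-aware helper from transformers' testing utilities: it maps (device_type, version) keys to expected strings, and get_expectation() returns the entry for the accelerator the test actually runs on, with None serving as a version wildcard. A hypothetical re-implementation of the lookup, purely to illustrate the pattern (the real helper's matching rules may differ):

    import torch

    # Hypothetical stand-in for the Expectations helper, for illustration only.
    def get_expectation(expectations: dict) -> str:
        # ROCm builds of torch expose torch.version.hip; CUDA builds do not.
        device = "rocm" if getattr(torch.version, "hip", None) else "cuda"
        # Prefer an entry for this device; fall back to the ("cuda", None)
        # entry, which plays the role of the default in the test above.
        for (dev, _version), text in expectations.items():
            if dev == device:
                return text
        return expectations[("cuda", None)]

    expected = get_expectation({
        ("cuda", None): "...default expected generation...",
        ("rocm", (9, 5)): "...ROCm-specific expected generation...",
    })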