Fix llava_next tests (#38813)

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2025-06-13 15:19:41 +02:00 committed by GitHub
parent b3b7789cbc
commit e39172ecab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -392,7 +392,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
load_in_4bit=True,
)
inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt").to(torch_device)
# verify inputs against original implementation
filepath = hf_hub_download(
@@ -415,11 +415,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
)
check_torch_load_is_safe()
original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
assert torch.allclose(
original_pixel_values, inputs.pixel_values.to(device="cpu", dtype=original_pixel_values.dtype)
)
# verify generation
output = model.generate(**inputs, max_new_tokens=100)
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L' # fmt: skip
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L'
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
@@ -511,7 +513,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# verify generation
output = model.generate(**inputs, max_new_tokens=50)
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the time of day seems to be either dawn or dusk, given the soft' # fmt: skip
EXPECTED_DECODED_TEXT = "[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given"
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,