mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Fix llava_next tests (#38813)
* fix
* fix
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
b3b7789cbc
commit
e39172ecab
@@ -392,7 +392,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
|
||||
inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt").to(torch_device)
|
||||
|
||||
# verify inputs against original implementation
|
||||
filepath = hf_hub_download(
|
||||
@@ -415,11 +415,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
check_torch_load_is_safe()
|
||||
original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
|
||||
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
|
||||
assert torch.allclose(
|
||||
original_pixel_values, inputs.pixel_values.to(device="cpu", dtype=original_pixel_values.dtype)
|
||||
)
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=100)
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L' # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L'
|
||||
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
@@ -511,7 +513,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
# verify generation
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the time of day seems to be either dawn or dusk, given the soft' # fmt: skip
|
||||
EXPECTED_DECODED_TEXT = "[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given"
|
||||
self.assertEqual(
|
||||
self.processor.decode(output[0], skip_special_tokens=True),
|
||||
EXPECTED_DECODED_TEXT,
|
||||
|
Loading…
Reference in New Issue
Block a user