diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index 93142f1da68..cb573913e4a 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -392,7 +392,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
             load_in_4bit=True,
         )
 
-        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt")
+        inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt").to(torch_device)
 
         # verify inputs against original implementation
         filepath = hf_hub_download(
@@ -415,11 +415,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
         check_torch_load_is_safe()
         original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
-        assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
+        assert torch.allclose(
+            original_pixel_values, inputs.pixel_values.to(device="cpu", dtype=original_pixel_values.dtype)
+        )
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=100)
-        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L'  # fmt: skip
+        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L'
 
         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=True),
@@ -511,7 +513,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=50)
-        EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the time of day seems to be either dawn or dusk, given the soft'  # fmt: skip
+        EXPECTED_DECODED_TEXT = "[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given"
         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=True),
             EXPECTED_DECODED_TEXT,