Expectations for llava_next_video

This commit is contained in:
remi-or 2025-06-27 07:56:17 -05:00
parent e8e0c76162
commit b7b9ae39e8

View File

@ -29,6 +29,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
Expectations,
cleanup,
require_bitsandbytes,
require_torch,
@ -378,12 +379,16 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# verify generation
output = model.generate(**inputs, do_sample=False, max_new_tokens=40)
EXPECTED_DECODED_TEXT = (
"USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems", # cuda output
"USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while wearing a pair of glasses that are too large for them. The glasses are", # xpu output
)
expected_decoded_text = Expectations(
{
("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems",
("xpu", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while wearing a pair of glasses that are too large for them. The glasses are",
("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The child is seen reading a book, but instead of turning the pages like one would typically do, they",
}
).get_expectation() # fmt: off
self.assertTrue(self.processor.decode(output[0], skip_special_tokens=True) in EXPECTED_DECODED_TEXT)
decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
self.assertEqual(decoded_text, expected_decoded_text)
@slow
@require_bitsandbytes
@ -400,15 +405,17 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
).to(torch_device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
decoded_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_DECODED_TEXT = [
'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a',
'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a'
] # fmt: skip
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
expected_decoded_text = Expectations(
{
("cuda", None): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a",
("rocm", (9, 5)): "USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and adorable behavior of the young child. The",
}
).get_expectation() # fmt: off
EXPECTED_DECODED_TEXT = [expected_decoded_text, expected_decoded_text]
self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
@ -435,8 +442,15 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# verify generation
output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"' # fmt: skip
self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT)
EXPECTED_DECODED_TEXT = Expectations(
{
("rocm", (9, 5)): "USER: \nWhat is shown in this image? ASSISTANT: The image displays a chart that appears to be a comparison of different models or versions of a machine learning (ML) model, likely a neural network, based on their performance on a task or dataset. The chart is a scatter plot with axes labeled",
("cuda", None): 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"',
}
).get_expectation() # fmt: off
decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
self.assertEqual(decoded_text, EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes