Mirror of https://github.com/huggingface/transformers.git
[video processor] fix tests (#38104)
* fix tests
* delete
* fix one more test
* fix qwen + some tests are failing irrespective of `VideoProcessor`
* delete file
parent 9b5ce556aa
commit aaf224d570
@@ -46,12 +46,15 @@ if TYPE_CHECKING:
 else:
     VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
         [
+            ("instructblip", "InstructBlipVideoVideoProcessor"),
             ("instructblipvideo", "InstructBlipVideoVideoProcessor"),
             ("internvl", "InternVLVideoProcessor"),
             ("llava_next_video", "LlavaNextVideoVideoProcessor"),
             ("llava_onevision", "LlavaOnevisionVideoProcessor"),
-            ("qwen2_5_vl", "Qwen2_5_VLVideoProcessor"),
+            ("qwen2_5_omni", "Qwen2VLVideoProcessor"),
+            ("qwen2_5_vl", "Qwen2VLVideoProcessor"),
             ("qwen2_vl", "Qwen2VLVideoProcessor"),
+            ("smolvlm", "SmolVLMVideoProcessor"),
             ("video_llava", "VideoLlavaVideoProcessor"),
         ]
     )
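For context, this registry is what the auto class consults when resolving a checkpoint's `model_type` to a concrete video processor class. A minimal sketch of that lookup, assuming the `AutoVideoProcessor` entry point exported by recent transformers releases; the checkpoint id is illustrative and not part of this commit:

```python
# Sketch only: resolve a video processor through the auto mapping.
# The checkpoint id is illustrative, not part of this commit.
from transformers import AutoVideoProcessor

# "llava-hf/LLaVA-NeXT-Video-7B-hf" reports model_type "llava_next_video",
# which the mapping above ties to LlavaNextVideoVideoProcessor.
video_processor = AutoVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
print(type(video_processor).__name__)  # expected: LlavaNextVideoVideoProcessor
```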
@@ -156,21 +156,17 @@ class VideoLlavaProcessor(ProcessorMixin):
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
             - **pixel_values_videos** -- Pixel values to be fed to a model. Returned when `videos` is not `None`.
         """
-        data = {}
-        if images is not None:
-            encoded_images = self.image_processor(images=images, return_tensors=return_tensors)
-            data.update(encoded_images)
-
-        if videos is not None:
-            encoded_videos = self.video_processor(videos=videos, return_tensors=return_tensors)
-            data.update(encoded_videos)
-
         if isinstance(text, str):
             text = [text]
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
-        if encoded_images is not None:
+        data = {}
+        if images is not None:
+            encoded_images = self.image_processor(images=images, return_tensors=return_tensors)
+            data.update(encoded_images)
+
             height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0]))
             num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
             num_image_tokens += self.num_additional_image_tokens
@@ -178,7 +174,10 @@ class VideoLlavaProcessor(ProcessorMixin):
                 num_image_tokens -= 1
             text = [sample.replace(self.image_token, self.image_token * num_image_tokens) for sample in text]
 
-        if encoded_videos is not None:
+        if videos is not None:
+            encoded_videos = self.video_processor(videos=videos, return_tensors=return_tensors)
+            data.update(encoded_videos)
+
             one_video = encoded_videos.get("pixel_values_videos")[0]
             if isinstance(encoded_videos.get("pixel_values_videos")[0], (list, tuple)):
                 one_video = np.array(one_video)
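The two hunks above reorder `VideoLlavaProcessor.__call__`: text is validated first, and images/videos are encoded only inside their own `if` blocks, so the token-expansion code no longer reads `encoded_images`/`encoded_videos` before they are guaranteed to exist. A hedged usage sketch of the updated flow; the checkpoint id, prompt, and printed shape are assumptions for illustration:

```python
# Illustrative only; checkpoint id and prompt are assumptions.
import numpy as np
from transformers import VideoLlavaProcessor

processor = VideoLlavaProcessor.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")

# Dummy 8-frame RGB video: (num_frames, height, width, channels).
video = np.random.randint(0, 255, (8, 224, 224, 3), dtype=np.uint8)

# A videos-only call never touches `encoded_images`; the video branch
# encodes the clip and expands the <video> tokens on its own.
inputs = processor(
    text="USER: <video>\nWhy is this video funny? ASSISTANT:",
    videos=video,
    return_tensors="pt",
)
print(inputs["pixel_values_videos"].shape)  # e.g. (1, 8, 3, 224, 224)
```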
@@ -415,7 +415,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
             "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True, cache_dir="./"
         )
 
-        inputs = self.processor(self.prompt_video, videos=self.video, return_tensors="pt")
+        inputs = self.processor(text=self.prompt_video, videos=self.video, return_tensors="pt")
         # verify single forward pass
         inputs = inputs.to(torch_device)
         with torch.no_grad():
@@ -438,7 +438,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
 
         inputs = self.processor(
-            [self.prompt_video, self.prompt_video],
+            text=[self.prompt_video, self.prompt_video],
             videos=[self.video, self.video],
             return_tensors="pt",
             padding=True,
@@ -465,7 +465,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
 
         inputs = self.processor(
-            [self.prompt_image, self.prompt_video],
+            text=[self.prompt_image, self.prompt_video],
             images=self.image,
             videos=self.video,
             return_tensors="pt",
@@ -491,7 +491,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
         )
 
         inputs_batched = self.processor(
-            [self.prompt_video, self.prompt_image],
+            text=[self.prompt_video, self.prompt_image],
             images=[self.image],
             videos=[self.video],
             return_tensors="pt",
@@ -648,7 +648,12 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
             self.messages[0],
             {
                 "role": "assistant",
-                "content": "The sound is glass shattering, and the dog appears to be a Labrador Retriever.",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "The sound is glass shattering, and the dog appears to be a Labrador Retriever.",
+                    }
+                ],
             },
             {
                 "role": "user",
@@ -687,7 +692,12 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         messages = [
             {
                 "role": "system",
-                "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.",
+                    }
+                ],
             },
             {
                 "role": "user",
@@ -697,7 +707,7 @@ class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
         audio, _ = librosa.load(BytesIO(urlopen(audio_url).read()), sr=self.processor.feature_extractor.sampling_rate)
 
         text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = self.processor(text=[text], audio=[audio], return_tensors="pt", padding=True).to(torch_device)
+        inputs = self.processor(text=text, audio=[audio], return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, thinker_temperature=0, thinker_do_sample=False)
 
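The Qwen2.5-Omni test hunks replace plain-string `content` fields with the structured list-of-blocks form the chat template expects, and pass the single prompt string returned by `apply_chat_template` as `text=text` rather than wrapped in a list. A minimal sketch of the updated message shape; only text blocks are shown, with wording taken from the tests above, and the commented calls assume the processor objects defined in those tests:

```python
# Sketch of the structured `content` format used after this change.
messages = [
    {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "The sound is glass shattering, and the dog appears to be a Labrador Retriever.",
            }
        ],
    },
]

# With tokenize=False the template renders one prompt string, which the
# test then passes as `text=text` (no extra list wrapping):
# text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# inputs = processor(text=text, audio=[audio], return_tensors="pt", padding=True)
```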
@@ -466,7 +466,7 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
             repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
         )
         video_file = np.load(video_file)
-        inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device)
+        inputs = self.processor(text=prompt, videos=video_file, return_tensors="pt").to(torch_device)
 
         EXPECTED_INPUT_IDS = torch.tensor([1, 3148, 1001, 29901, 29871, 13, 11008, 338, 445, 4863, 2090, 1460, 29973, 319, 1799, 9047, 13566, 29901], device=torch_device)  # fmt: skip
         non_video_inputs = inputs["input_ids"][inputs["input_ids"] != 32001]
@@ -496,9 +496,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
         url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         image = Image.open(requests.get(url, stream=True).raw)
 
-        inputs = self.processor(prompts, images=[image], videos=[video_file], padding=True, return_tensors="pt").to(
-            torch_device
-        )
+        inputs = self.processor(
+            text=prompts, images=[image], videos=[video_file], padding=True, return_tensors="pt"
+        ).to(torch_device)
         output = model.generate(**inputs, do_sample=False, max_new_tokens=20)
 
         EXPECTED_DECODED_TEXT = [
@@ -522,7 +522,7 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
             repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
         )
         video_file = np.load(video_file)
-        inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
+        inputs = self.processor(text=prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
 
         output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
         EXPECTED_DECODED_TEXT = "USER: \nDescribe the video in details. ASSISTANT: The video features a young child sitting on a bed, holding a book and reading it. " \
@@ -554,7 +554,7 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
             hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset")
         )
 
-        inputs = processor(prompts, videos=[video_1, video_2], return_tensors="pt", padding=True).to(torch_device)
+        inputs = processor(text=prompts, videos=[video_1, video_2], return_tensors="pt", padding=True).to(torch_device)
 
         output = model.generate(**inputs, max_new_tokens=20)
 
@@ -71,8 +71,8 @@ class BaseVideoProcessorTester(unittest.TestCase):
 
         # Test a list of videos is converted to a list of 1 video
         video = get_random_video(16, 32)
-        video = [PIL.Image.fromarray(frame) for frame in video]
-        videos_list = make_batched_videos(video)
+        pil_video = [PIL.Image.fromarray(frame) for frame in video]
+        videos_list = make_batched_videos(pil_video)
         self.assertIsInstance(videos_list, list)
         self.assertIsInstance(videos_list[0], np.ndarray)
         self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
@@ -80,8 +80,8 @@ class BaseVideoProcessorTester(unittest.TestCase):
 
         # Test a nested list of videos is not modified
         video = get_random_video(16, 32)
-        video = [PIL.Image.fromarray(frame) for frame in video]
-        videos = [video, video]
+        pil_video = [PIL.Image.fromarray(frame) for frame in video]
+        videos = [pil_video, pil_video]
         videos_list = make_batched_videos(videos)
         self.assertIsInstance(videos_list, list)
         self.assertIsInstance(videos_list[0], np.ndarray)
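The last two hunks rename the PIL frame list so the original numpy video is not shadowed before the shape assertions run. For reference, a hedged sketch of the behaviour these tests pin down; the `make_batched_videos` import path is assumed from the test module:

```python
# Illustration of the expected normalisation; the import path is an assumption.
import numpy as np
import PIL.Image
from transformers.video_utils import make_batched_videos

frames = np.random.randint(0, 255, (8, 16, 32, 3), dtype=np.uint8)
pil_video = [PIL.Image.fromarray(frame) for frame in frames]

# A flat list of PIL frames counts as one video and is normalised to a
# list holding a single (8, 16, 32, 3) numpy array.
single = make_batched_videos(pil_video)
assert single[0].shape == (8, 16, 32, 3)

# A nested list is already one-entry-per-video and keeps its length.
batched = make_batched_videos([pil_video, pil_video])
assert len(batched) == 2 and batched[0].shape == (8, 16, 32, 3)
```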