[chat templates] support loading audio from video (#36955)

* add audio from video

* typos

* delete print

* comments
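
For context, a minimal usage sketch of the feature this commit adds (the checkpoint name, video path, and prompts below are illustrative placeholders, not taken from the diff):

from transformers import AutoProcessor

# hypothetical multimodal checkpoint with both video and audio support
processor = AutoProcessor.from_pretrained("some-org/omni-model")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "path": "sample.mp4"},
            {"type": "text", "text": "What is being said in this video?"},
        ],
    }
]

# with load_audio_from_video=True, the audio track is decoded from the
# video file itself instead of requiring a separate audio input
out = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    add_generation_prompt=True,
    load_audio_from_video=True,
)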
Raushan Turganbay 2025-03-27 14:46:11 +01:00 committed by GitHub
parent c7bc79bd2a
commit e97c760006
2 changed files with 129 additions and 58 deletions

View File

@@ -23,7 +23,7 @@ import sys
import typing
import warnings
from pathlib import Path
from typing import Any, Callable, Optional, TypedDict, Union
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union
import numpy as np
import typing_extensions
@@ -386,14 +386,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False):
return_assistant_tokens_mask: Optional[bool] = False
class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
class ChatTemplateLoadKwargs(TypedDict, total=False):
"""
Keyword arguments for processor chat templates.
Keyword arguments used to load multimodal data in processor chat templates.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
video_load_backend (`str`, *optional*, defaults to `"pyav"`):
@@ -415,13 +411,26 @@ class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
"""
tokenize: Optional[bool] = False
return_dict: Optional[bool] = False
num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None
sampling_rate: Optional[int] = 16_000
sample_indices_fn: Optional[Callable] = None
load_audio_from_video: Optional[bool] = False
class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False):
"""
Keyword arguments for processor's `apply_chat_template`.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
"""
tokenize: Optional[bool] = False
return_dict: Optional[bool] = False
class AllKwargsForChatTemplate(
@@ -1236,11 +1245,11 @@ class ProcessorMixin(PushToHubMixin):
def _process_messages_for_chat_template(
self,
conversation: list[list[dict[str, str]]],
batch_images: list[ImageInput],
batch_videos: list[VideoInput],
batch_video_metadata: list[list[dict[str, any]]],
**chat_template_kwargs: Unpack[AllKwargsForChatTemplate],
conversation: List[List[Dict[str, str]]],
batch_images: List[ImageInput],
batch_videos: List[VideoInput],
batch_video_metadata: List[List[Dict[str, any]]],
**mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
):
"""
Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
@@ -1311,18 +1320,18 @@ class ProcessorMixin(PushToHubMixin):
)
# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
# and for multimodal chat template
# and for multimodal data loading. Everything else will be used in `__call__`
tokenizer_template_kwargs = {}
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, tokenizer_value)
default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, default_value)
tokenizer_template_kwargs[tokenizer_key] = value
chat_template_kwargs = {}
for key in ProcessorChatTemplateKwargs.__annotations__.keys():
processor_value = getattr(ProcessorChatTemplateKwargs, key, None)
value = kwargs.pop(key, processor_value)
chat_template_kwargs[key] = value
mm_load_kwargs = {}
for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys():
default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None)
value = kwargs.pop(mm_load_key, default_value)
mm_load_kwargs[mm_load_key] = value
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
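
The two loops above claim their keys by reading each TypedDict's `__annotations__`. A standalone sketch of that partitioning pattern, using a hypothetical ExampleLoadKwargs rather than the real classes in this diff:

import typing

class ExampleLoadKwargs(typing.TypedDict, total=False):
    num_frames: typing.Optional[int]
    video_fps: typing.Optional[int]

def split_kwargs(kwargs: dict) -> tuple[dict, dict]:
    # pop every key the TypedDict declares, falling back to a
    # class-level default (None here) when the caller omitted it
    picked = {}
    for key in ExampleLoadKwargs.__annotations__:
        picked[key] = kwargs.pop(key, getattr(ExampleLoadKwargs, key, None))
    return picked, kwargs

load_kwargs, rest = split_kwargs({"num_frames": 8, "padding": True})
# load_kwargs -> {"num_frames": 8, "video_fps": None}
# rest -> {"padding": True}, later forwarded to the processor's __call__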
@@ -1333,13 +1342,8 @@ class ProcessorMixin(PushToHubMixin):
is_batched = False
conversations = [conversation]
num_frames = chat_template_kwargs.get("num_frames")
video_fps = chat_template_kwargs.get("video_fps")
video_load_backend = chat_template_kwargs.get("video_load_backend")
tokenize = chat_template_kwargs.get("tokenize")
return_dict = chat_template_kwargs.get("return_dict")
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
sampling_rate = chat_template_kwargs.pop("sampling_rate")
tokenize = kwargs.pop("tokenize", False)
return_dict = kwargs.pop("return_dict", False)
if tokenize:
batch_images, batch_videos = [], []
@@ -1369,31 +1373,37 @@ class ProcessorMixin(PushToHubMixin):
if key in vision_info and vision_info["type"] == "video"
]
# Audio models do not accept nested list of audios (yet!)
for fname in audio_fnames:
batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))
for fname in image_fnames:
images.append(load_image(fname))
for fname in video_fnames:
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
video = [np.array(load_image(image_fname)).T for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array
video = np.stack(video)
metadata = None
logger.warning(
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
"If you model applies special processing based on metadata, please load the whole video and let the model sample frames."
)
else:
video, metadata = load_video(
fname,
num_frames=num_frames,
fps=video_fps,
backend=video_load_backend,
sample_indices_fn=sample_indices_fn,
)
videos.append(video)
video_metadata.append(metadata)
# Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
if not mm_load_kwargs["load_audio_from_video"]:
for fname in audio_fnames:
batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
else:
for fname in video_fnames:
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
video = [np.array(load_image(image_fname)).T for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array
video = np.stack(video)
metadata = None
audios = None
logger.warning(
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
)
else:
video, metadata = load_video(
fname,
num_frames=mm_load_kwargs["num_frames"],
fps=mm_load_kwargs["video_fps"],
backend=mm_load_kwargs["video_load_backend"],
sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
)
audios = load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])
batch_audios.append(audios)
videos.append(video)
video_metadata.append(metadata)
# Currently all processors can accept nested list of batches, but not flat list of visuals
# So we'll make a batched list of images and let the processor handle it
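
The new branch decodes the audio track straight out of the video container. Roughly what that step amounts to can be sketched with librosa (transformers' load_audio is librosa-based; this assumes an ffmpeg-capable backend is installed and the video actually carries an audio stream):

import librosa

# with ffmpeg available, librosa can read the audio stream out of
# common video containers such as mp4, resampling to 16 kHz here
waveform, sr = librosa.load("sample_demo_1.mp4", sr=16_000)
print(waveform.shape, sr)  # mono float32 array and its sampling rate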
@@ -1409,7 +1419,7 @@ class ProcessorMixin(PushToHubMixin):
batch_images=batch_images,
batch_videos=batch_videos,
batch_video_metadata=batch_video_metadata,
**chat_template_kwargs,
**mm_load_kwargs,
)
prompt = self.tokenizer.apply_chat_template(
@@ -1438,7 +1448,7 @@ class ProcessorMixin(PushToHubMixin):
text=prompt,
images=batch_images if batch_images else None,
videos=batch_videos if batch_videos else None,
audios=batch_audios if batch_audios else None,
audio=batch_audios if batch_audios else None,
**kwargs,
)
if return_dict:

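One further note on the load kwargs: the sample_indices_fn hook, whose docstring example is truncated in the hunk above, lets callers replace uniform frame sampling. A hedged sketch, reusing processor and messages from the first sketch and assuming the callback receives the sampling parameters and video metadata as keyword arguments:

import numpy as np

def take_first_frames(num_frames=None, fps=None, metadata=None, **kwargs):
    # hypothetical sampler: ignore fps and just keep the first 8 frames
    return np.arange(8, dtype=int)

out = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    sample_indices_fn=take_first_frames,
)
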
View File

@@ -1097,10 +1097,7 @@ class ProcessorTesterMixin:
{
"role": "user",
"content": [
{
"type": "video",
"path": video_file_path,
},
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
@@ -1189,6 +1186,70 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
@require_librosa
@require_av
def test_audio_chat_template_from_video(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest(f"{self.processor_class} does not suport video inputs")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "Which of these animals is making the sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is a cow."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "Is it the same sound?"},
],
},
]
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1) # batch size=1
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="np",
load_audio_from_video=True,
)
self.assertTrue(self.audio_input_name in out_dict)
self.assertTrue(self.videos_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
self.assertEqual(len(out_dict[self.videos_input_name]), 1) # 1 video in the conversation
@require_librosa
def test_audio_chat_template_single(self):
processor = self.get_processor()