[chat-template] fix video loading (#37146)

* fix

* add video

* trigger

* push new images

* fix tests

* revert

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Raushan Turganbay 2025-04-02 11:27:50 +02:00 committed by GitHub
parent 800510c67b
commit 211e4dc9a4
4 changed files with 247 additions and 23 deletions

@@ -7,5 +7,5 @@ ENV UV_PYTHON=/usr/local/bin/python
 RUN pip --no-cache-dir install uv && uv venv && uv pip install --no-cache-dir -U pip setuptools
 RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
 RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
-RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words]"
+RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words,video]"
 RUN uv pip uninstall transformers
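The only change here is appending `video` to the transformers extras so the CI image can decode video. The new tests below are gated on `require_av`, so the extra presumably pulls in PyAV; a quick sanity check of the built image (the `av` package name and the sample path are assumptions, not part of this diff):

import av  # assumed to be provided by the `video` extra

# Open a short clip and print basic stream info; the filename is a placeholder.
with av.open("sample_demo_1.mp4") as container:
    stream = container.streams.video[0]
    print(stream.frames, "frames at", float(stream.average_rate or 0), "fps")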

@@ -1382,28 +1382,28 @@ class ProcessorMixin(PushToHubMixin):
                             batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
                 else:
                     for fname in video_fnames:
-                        if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
-                            video = [np.array(load_image(image_fname)).T for image_fname in fname]
-                            # create a 4D video because `load_video` always returns a 4D array
-                            video = np.stack(video)
-                            metadata = None
-                            audios = None
-                            logger.warning(
-                                "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
-                                "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
-                            )
-                        else:
-                            video, metadata = load_video(
-                                fname,
-                                num_frames=mm_load_kwargs["num_frames"],
-                                fps=mm_load_kwargs["video_fps"],
-                                backend=mm_load_kwargs["video_load_backend"],
-                                sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
-                            )
-                            audios = load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])
-                        batch_audios.append(audios)
-                        videos.append(video)
-                        video_metadata.append(metadata)
+                        batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+
+                for fname in video_fnames:
+                    if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
+                        video = [np.array(load_image(image_fname)).T for image_fname in fname]
+                        # create a 4D video because `load_video` always returns a 4D array
+                        video = np.stack(video)
+                        metadata = None
+                        logger.warning(
+                            "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
+                            "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
+                        )
+                    else:
+                        video, metadata = load_video(
+                            fname,
+                            num_frames=mm_load_kwargs["num_frames"],
+                            fps=mm_load_kwargs["video_fps"],
+                            backend=mm_load_kwargs["video_load_backend"],
+                            sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
+                        )
+                    videos.append(video)
+                    video_metadata.append(metadata)
             # Currently all processors can accept nested list of batches, but not flat list of visuals
             # So we'll make a batched list of images and let the processor handle it
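For reference, the block patched above runs inside `apply_chat_template` when a tokenized conversation contains video entries: with the fix, videos are loaded for every conversation, and audio is only pulled from the video files when the model opts into loading audio from video. A minimal usage sketch of that path (the checkpoint name and file path are illustrative, not taken from this commit):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # any video-capable processor

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "path": "sample_demo_1.mp4"},  # placeholder path
            {"type": "text", "text": "What is shown in this video?"},
        ],
    },
]

# The video is loaded (and optionally frame-sampled) before the rendered prompt is tokenized.
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True
)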

@@ -18,6 +18,7 @@ import tempfile
import unittest
import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_av, require_torch, require_vision
@@ -326,3 +327,114 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
        self.assertEqual(inputs[self.images_input_name].shape[0], 612)
        inputs = processor(text=input_str, images=image_input, return_tensors="pt")
        self.assertEqual(inputs[self.images_input_name].shape[0], 800)
    @require_av
    def test_chat_template_video_custom_sampling(self):
        """
        Tests that models can pass their custom callables to sample video indices.
        """
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        signature = inspect.signature(processor.__call__)
        if "videos" not in {*signature.parameters.keys()} or (
            signature.parameters.get("videos") is not None
            and signature.parameters["videos"].annotation == inspect._empty
        ):
            self.skipTest("Processor doesn't accept videos at input")

        video_file_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_file_path},
                        {"type": "text", "text": "What is shown in this video?"},
                    ],
                },
            ]
        ]

        def dummy_sample_indices_fn(metadata, **fn_kwargs):
            # always sample only the first two frames
            return [0, 1]

        out_dict_with_video = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            sample_indices_fn=dummy_sample_indices_fn,
        )
        self.assertTrue(self.videos_input_name in out_dict_with_video)
        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 14400)
    @require_av
    def test_chat_template_video_special_processing(self):
        """
        Tests that models can use their own preprocessing to preprocess conversations.
        """
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        signature = inspect.signature(processor.__call__)
        if "videos" not in {*signature.parameters.keys()} or (
            signature.parameters.get("videos") is not None
            and signature.parameters["videos"].annotation == inspect._empty
        ):
            self.skipTest("Processor doesn't accept videos at input")

        video_file_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_file_path},
                        {"type": "text", "text": "What is shown in this video?"},
                    ],
                },
            ]
        ]

        def _process_messages_for_chat_template(
            conversation,
            batch_images,
            batch_videos,
            batch_video_metadata,
            **chat_template_kwargs,
        ):
            # Let us just always return a dummy prompt
            new_msg = [
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "video"},  # no need for a path, the video is already loaded at this point
                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
                        ],
                    },
                ]
            ]
            return new_msg

        processor._process_messages_for_chat_template = _process_messages_for_chat_template
        out_dict_with_video = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
        )
        self.assertTrue(self.videos_input_name in out_dict_with_video)

        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1756800)
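The `dummy_sample_indices_fn` used above pins the first two frames just to keep the test deterministic; a real callable would typically use the video metadata to choose frames. A small sketch of a uniform sampler (it assumes the metadata object passed in exposes `total_num_frames`, as the metadata returned by `load_video` does; the helper name is ours):

import numpy as np

def uniform_sample_indices_fn(metadata, num_frames=8, **fn_kwargs):
    # Spread `num_frames` indices evenly over the whole clip.
    total_frames = metadata.total_num_frames  # assumed metadata attribute
    num_frames = min(num_frames, total_frames)
    return np.linspace(0, total_frames - 1, num_frames).round().astype(int).tolist()

# Passed exactly like the dummy version in the test:
# processor.apply_chat_template(..., sample_indices_fn=uniform_sample_indices_fn)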

@@ -18,6 +18,7 @@ import tempfile
import unittest
import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_av, require_torch, require_vision
@@ -308,6 +309,117 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
        self.assertTrue(self.videos_input_name in out_dict_with_video)
        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 71280)
    @require_av
    def test_chat_template_video_custom_sampling(self):
        """
        Tests that models can pass their custom callables to sample video indices.
        """
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        signature = inspect.signature(processor.__call__)
        if "videos" not in {*signature.parameters.keys()} or (
            signature.parameters.get("videos") is not None
            and signature.parameters["videos"].annotation == inspect._empty
        ):
            self.skipTest("Processor doesn't accept videos at input")

        video_file_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_file_path},
                        {"type": "text", "text": "What is shown in this video?"},
                    ],
                },
            ]
        ]

        def dummy_sample_indices_fn(metadata, **fn_kwargs):
            # always sample only the first two frames
            return [0, 1]

        out_dict_with_video = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            sample_indices_fn=dummy_sample_indices_fn,
        )
        self.assertTrue(self.videos_input_name in out_dict_with_video)
        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 14400)
    @require_av
    def test_chat_template_video_special_processing(self):
        """
        Tests that models can use their own preprocessing to preprocess conversations.
        """
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        signature = inspect.signature(processor.__call__)
        if "videos" not in {*signature.parameters.keys()} or (
            signature.parameters.get("videos") is not None
            and signature.parameters["videos"].annotation == inspect._empty
        ):
            self.skipTest("Processor doesn't accept videos at input")

        video_file_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "path": video_file_path},
                        {"type": "text", "text": "What is shown in this video?"},
                    ],
                },
            ]
        ]

        def _process_messages_for_chat_template(
            conversation,
            batch_images,
            batch_videos,
            batch_video_metadata,
            **chat_template_kwargs,
        ):
            # Let us just always return a dummy prompt
            new_msg = [
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "video"},  # no need for a path, the video is already loaded at this point
                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
                        ],
                    },
                ]
            ]
            return new_msg

        processor._process_messages_for_chat_template = _process_messages_for_chat_template
        out_dict_with_video = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
        )
        self.assertTrue(self.videos_input_name in out_dict_with_video)

        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1756800)

    def test_kwargs_overrides_custom_image_processor_kwargs(self):
        processor_components = self.prepare_components()
        processor_components["image_processor"] = self.get_component("image_processor")