[chat-template] Unify tests and clean up 🧼 (#37275)

* fix tests and some clean up

* make one general test for each modality

* remove redundant merging of kwargs

* edge cases

* don't enforce slow when reloading

* fix gemma3 tests

* has to adapt llama 4 after rebase

* also remove from overridden tests

* should be green now
Raushan Turganbay authored on 2025-04-10 14:42:32 +02:00, committed by GitHub
parent 10144ff116
commit 1ae8d54b04
18 changed files with 389 additions and 1112 deletions

View File

@ -181,35 +181,6 @@ processed_chat = processor.apply_chat_template(
print(processed_chat.keys())
```
</hfoption>
<hfoption id="custom frame sampling">
Some models don't sample frames *uniformly* and require more complex logic to determine which frames to use. For example, a model may use *adaptive frame selection* or prioritize *key moments* in a video rather than evenly spaced frames.
If a model has a different sampling strategy, you can write a function that customizes frame selection. The function should meet the following requirements.
- Use the `sample_indices_fn` parameter to pass a callable function for sampling.
- If provided, this function *overrides* the standard `num_frames` and `fps` parameters.
- The function receives all the parameters passed to `load_video` and must return valid frame indices to sample from.
An example function is shown below. This gives you full control over frame selection, making the model more adaptable to different video scenarios.
```py
def sample_indices_fn(metadata, **kwargs):
    # sample only the first and second frames
    return [0, 1]

processed_chat = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    sample_indices_fn=sample_indices_fn,
    video_load_backend="decord",
)
print(processed_chat.keys())
```
</hfoption>
<hfoption id="list of image frames">

View File

@ -20,10 +20,13 @@ import copy
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import (
ImageInput,
VideoInput,
load_video,
make_batched_videos,
make_nested_list_of_images,
)
@ -425,32 +428,44 @@ class SmolVLMProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
# Add model-specific video sampling method when applying the template
def apply_chat_template(
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def _load_video_for_model(
self,
conversation,
max_frames=None,
target_fps=None,
skip_secs=1,
video_load_backend="pyav",
sample_indices_fn=None,
**kwargs,
):
max_frames = self.default_max_frames if max_frames is None else max_frames
target_fps = self.default_fps if target_fps is None else target_fps
video: Union[str, "VideoInput"],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
skip_secs: float = 0.0,
) -> np.array:
"""
Loads `video` to a numpy array.
Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to a video or a local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
max_frames = self.default_max_frames if num_frames is None else num_frames
target_fps = self.default_fps if fps is None else fps
def sample_indices_fn_func(metadata, **fn_kwargs):
return smolvlm_sample_indices_fn(
metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
)
# word of caution- we are blindly overriding a callable kwarg here.
# typed kwargs would be a way to avoid that @molbap
if not sample_indices_fn:
sample_indices_fn = sample_indices_fn_func
return super().apply_chat_template(
conversation, video_load_backend=video_load_backend, sample_indices_fn=sample_indices_fn, **kwargs
)
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata
__all__ = ["SmolVLMProcessor"]

View File

@ -23,7 +23,7 @@ import sys
import typing
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union
from typing import Any, Dict, List, Optional, TypedDict, Union
import numpy as np
import typing_extensions
@ -415,7 +415,6 @@ class ChatTemplateLoadKwargs(TypedDict, total=False):
video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None
sampling_rate: Optional[int] = 16_000
sample_indices_fn: Optional[Callable] = None
load_audio_from_video: Optional[bool] = False
@ -435,7 +434,16 @@ class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateK
class AllKwargsForChatTemplate(
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
): ...
):
processor_kwargs: ProcessingKwargs = {
**ProcessingKwargs.__annotations__,
}
mm_load_kwargs: ChatTemplateLoadKwargs = {
**TextKwargs.__annotations__,
}
template_kwargs: ProcessorChatTemplateKwargs = {
**ProcessorChatTemplateKwargs.__annotations__,
}
class ProcessorMixin(PushToHubMixin):
@ -1315,19 +1323,20 @@ class ProcessorMixin(PushToHubMixin):
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
)
# Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template`
# and for multimodal data loading. Everything else will be used in `__call__`
tokenizer_template_kwargs = {}
for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys():
default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None)
value = kwargs.pop(tokenizer_key, default_value)
tokenizer_template_kwargs[tokenizer_key] = value
# Fill sets of kwargs that should be used by different parts of template
processed_kwargs = {
"processor_kwargs": {},
"mm_load_kwargs": {},
"template_kwargs": {},
}
mm_load_kwargs = {}
for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys():
default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None)
value = kwargs.pop(mm_load_key, default_value)
mm_load_kwargs[mm_load_key] = value
for kwarg_type in processed_kwargs:
for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys():
kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
default_value = getattr(kwarg_type_defaults, key, None)
value = kwargs.pop(key, default_value)
if value is not None and not isinstance(value, dict):
processed_kwargs[kwarg_type][key] = value
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
@ -1338,8 +1347,9 @@ class ProcessorMixin(PushToHubMixin):
is_batched = False
conversations = [conversation]
tokenize = kwargs.pop("tokenize", False)
return_dict = kwargs.pop("return_dict", False)
tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False)
return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False)
mm_load_kwargs = processed_kwargs["mm_load_kwargs"]
if tokenize:
batch_images, batch_videos = [], []
@ -1382,7 +1392,7 @@ class ProcessorMixin(PushToHubMixin):
for fname in video_fnames:
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
video = [np.array(load_image(image_fname)).T for image_fname in fname]
video = [np.array(load_image(image_fname)) for image_fname in fname]
# create a 4D video because `load_video` always returns a 4D array
video = np.stack(video)
metadata = None
@ -1391,12 +1401,13 @@ class ProcessorMixin(PushToHubMixin):
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
)
else:
video, metadata = load_video(
# TODO: raushan, should be `self.video_processor.load_video_for_model` when API is added
video, metadata = self._load_video_for_model(
fname,
num_frames=mm_load_kwargs["num_frames"],
fps=mm_load_kwargs["video_fps"],
num_frames=mm_load_kwargs.get("num_frames", None),
fps=mm_load_kwargs.get("video_fps", None),
backend=mm_load_kwargs["video_load_backend"],
sample_indices_fn=mm_load_kwargs["sample_indices_fn"],
**kwargs,
)
videos.append(video)
video_metadata.append(metadata)
@ -1415,7 +1426,7 @@ class ProcessorMixin(PushToHubMixin):
batch_images=batch_images,
batch_videos=batch_videos,
batch_video_metadata=batch_video_metadata,
**mm_load_kwargs,
**processed_kwargs["mm_load_kwargs"],
)
prompt = self.tokenizer.apply_chat_template(
@ -1423,7 +1434,7 @@ class ProcessorMixin(PushToHubMixin):
chat_template=chat_template,
tokenize=False,
return_dict=False,
**tokenizer_template_kwargs,
**processed_kwargs["template_kwargs"],
)
if not is_batched:
@ -1438,14 +1449,14 @@ class ProcessorMixin(PushToHubMixin):
# without actionable solution for users
single_prompt = prompt[0] if is_batched else prompt
if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
kwargs["add_special_tokens"] = False
processed_kwargs["processor_kwargs"]["add_special_tokens"] = False
out = self(
text=prompt,
images=batch_images if batch_images else None,
videos=batch_videos if batch_videos else None,
audio=batch_audios if batch_audios else None,
**kwargs,
**processed_kwargs["processor_kwargs"],
)
if return_dict:
return out
@ -1453,6 +1464,37 @@ class ProcessorMixin(PushToHubMixin):
return out["input_ids"]
return prompt
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
# Keep private so we can simply remove when needed
def _load_video_for_model(
self,
video: Union[str, "VideoInput"],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
) -> np.array:
"""
Loads `video` to a numpy array.
Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to a video or a local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
video, metadata = load_video(video, num_frames, fps=fps, backend=backend)
return video, metadata
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
"""
Post-process the output of a vlm to decode the text.
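Taken together, the regrouped kwargs mean a caller can pass template, multimodal-loading, and `__call__` kwargs through a single `apply_chat_template` call and have each routed to the right consumer. Below is a hedged sketch of such a call; the checkpoint, URL, and kwarg choices mirror the tests elsewhere in this diff rather than this hunk.

```py
# Illustrative only: the call mixes template kwargs (tokenize, return_dict),
# multimodal loading kwargs (video_load_backend, num_frames) and processor
# __call__ kwargs (padding, return_tensors); the logic above splits them apart.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
            },
            {"type": "text", "text": "What is shown in this video?"},
        ],
    },
]
out = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,              # template kwarg
    return_dict=True,           # template kwarg
    num_frames=4,               # video loading/sampling kwarg
    video_load_backend="pyav",  # video loading kwarg
    padding=True,               # forwarded to the processor's __call__
    return_tensors="np",        # forwarded to the processor's __call__
)
print(out.keys())
```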

View File

@ -236,55 +236,6 @@ And who is that?<|im_end|>
"""
self.assertEqual(rendered, expected_rendered)
# Override as AriaImageProcessor doesn't accept `do_rescale`
def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
max_image_size=980,
return_tensors="np",
)
self.assertListEqual(list(out_dict[self.images_input_name].shape), [1, 3, 980, 980])
# Override as AriaProcessor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:

View File

@ -79,11 +79,6 @@ class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
# todo: yoni, fix this test
@unittest.skip("Chat template has long system prompt")
def test_chat_template_accepts_processing_kwargs(self, **kwargs):
pass
# Override as AyaVisionProcessor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:

View File

@ -86,67 +86,3 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor = LlavaProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
def test_chat_template(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_dict(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
out_dict_with_image = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
def test_chat_template_with_continue_final_message(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Describe this image."},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "There is a dog and"},
],
},
]
prompt = processor.apply_chat_template(messages, continue_final_message=True)
self.assertEqual(expected_prompt, prompt)

View File

@ -78,23 +78,6 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_chat_template(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_image_token_filling(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
processor.patch_size = 14

View File

@ -18,7 +18,7 @@ import tempfile
import unittest
from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@ -28,7 +28,7 @@ if is_vision_available():
from transformers import LlavaNextImageProcessor, LlavaNextVideoImageProcessor
if is_torch_available:
import torch
pass
@require_vision
@ -90,79 +90,3 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
def test_chat_template(self):
processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
@require_av
def test_chat_template_dict(self):
processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
messages = [
{
"role": "user",
"content": [
{"type": "video"},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
)
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 13, 5618, 338, 4318, 297, 445, 4863, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
}
out_dict_with_video = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
@require_torch
@require_av
def test_chat_template_dict_torch(self):
processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
messages = [
{
"role": "user",
"content": [
{
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))

View File

@ -16,7 +16,7 @@ import shutil
import tempfile
import unittest
from transformers.testing_utils import require_av, require_vision
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@ -93,50 +93,3 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_chat_template(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
expected_prompt = "<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
@require_av
def test_chat_template_dict(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
messages = [
{
"role": "user",
"content": [
{"type": "video"},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = [[151644, 872, 220, 151647, 198, 3838, 374, 6839, 304, 419, 2766, 30, 151645, 151644, 77091, 198]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
}
out_dict_with_video = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])

View File

@ -62,77 +62,6 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_chat_template_accepts_processing_kwargs(self):
# override to use slow image processor to return numpy arrays
processor = self.processor_class.from_pretrained(self.tmpdirname, use_fast=False)
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
do_rescale=True,
rescale_factor=-1,
return_tensors="np",
)
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname, use_fast=False)
expected_prompt = "<s>[SYSTEM_PROMPT][/SYSTEM_PROMPT][INST][IMG]What is shown in this image?[/INST]"
messages = [
{
"role": "system",
"content": "",
},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
# Important to check with non square image

View File

@ -51,22 +51,6 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_chat_template(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
# Important to check with non square image

View File

@ -17,12 +17,13 @@ import shutil
import tempfile
import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@ -30,6 +31,9 @@ from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Qwen2_5_VLProcessor, Qwen2VLImageProcessor
if is_torch_available():
import torch
@require_vision
@require_torch
@ -119,101 +123,97 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_image_chat_template_single(self):
@require_torch
def _test_apply_chat_template(
self,
modality: str,
batch_size: int,
return_tensors: str,
input_name: str,
processor_name: str,
input_data: list[str],
):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
if processor_name not in self.processor_class.attributes:
self.skipTest(f"{processor_name} attribute not present in {self.processor_class}")
batch_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
"content": [{"type": "text", "text": "Describe this."}],
},
]
]
] * batch_size
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1)
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1)
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
def test_image_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "What do you see?"},
],
},
],
]
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 2)
# Test that jinja can be applied
formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), batch_size)
# Test that tokenizing with template and directly with `self.tokenizer` gives same output
formatted_prompt_tokenized = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors
)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None, padding=True).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
tok_output = processor.tokenizer(
formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens
)
expected_output = tok_output.input_ids
self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
# Test that kwargs passed to processor's `__call__` are actually used
tokenized_prompt_100 = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
return_tensors=return_tensors,
max_length=100,
)
self.assertEqual(len(tokenized_prompt_100[0]), 100)
# Test that `return_dict=True` returns text related inputs in the dict
out_dict_text = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
)
self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
# Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
for idx, url in enumerate(input_data[:batch_size]):
batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4, # by default no more than 4 frames, otherwise too slow
)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
self.assertEqual(len(out_dict[input_name]), batch_size * 19200)
# Now test the ability to return dict
batched_messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
batched_messages[1][0]["content"].append(
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
)
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 2)
self.assertEqual(len(out_dict["attention_mask"]), 2)
self.assertEqual(len(out_dict[self.images_input_name]), 90480)
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
for k in out_dict:
self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors])
@require_av
def test_chat_template_video(self):
def test_apply_chat_template_video_frame_sampling(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
@ -331,52 +331,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(inputs[self.images_input_name].shape[0], 800)
@require_av
def test_chat_template_video_custom_sampling(self):
"""
Tests that models can pass their custom callables to sample video indices.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def dummy_sample_indices_fn(metadata, **fn_kwargs):
# always sample only the first two frames
return [0, 1]
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
sample_indices_fn=dummy_sample_indices_fn,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 14400)
@require_av
def test_chat_template_video_special_processing(self):
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
@ -433,6 +388,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)

View File

@ -54,7 +54,7 @@ class Qwen2AudioProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@staticmethod
def prepare_processor_dict():
return {
"chat_template": "{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}{% set audio_count.value = audio_count.value + 1 %}Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
"chat_template": "{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if 'audio' in content or 'audio_url' in content or content['type'] == 'audio' %}{% set audio_count.value = audio_count.value + 1 %}Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
}
# Override as Qwen2AudioProcessor needs audio tokens in prompts
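The template fix above (`message['type']` → `content['type']`) matters for audio content entries that declare themselves only through their `type` plus a URL or path. A hedged illustration of a message that relies on the corrected check:

```py
# Illustration only: this content entry has neither an "audio" nor an "audio_url"
# key, and `message['type']` is undefined on the outer message dict, so the old
# template would render no audio placeholder for it. The corrected
# `content['type'] == 'audio'` check picks it up.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "audio",
                "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
            },
            {"type": "text", "text": "What's that sound?"},
        ],
    },
]
```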
@ -159,29 +159,3 @@ class Qwen2AudioProcessorTest(ProcessorTesterMixin, unittest.TestCase):
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_with_continue_final_message(self):
processor = AutoProcessor.from_pretrained(self.checkpoint)
expected_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of " # fmt: skip
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of "}],
},
]
prompt = processor.apply_chat_template(messages, continue_final_message=True)
self.assertEqual(expected_prompt, prompt)

View File

@ -17,12 +17,13 @@ import shutil
import tempfile
import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@ -30,6 +31,9 @@ from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Qwen2VLImageProcessor, Qwen2VLProcessor
if is_torch_available():
import torch
@require_vision
@require_torch
@ -116,101 +120,97 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_image_chat_template_single(self):
@require_torch
def _test_apply_chat_template(
self,
modality: str,
batch_size: int,
return_tensors: str,
input_name: str,
processor_name: str,
input_data: list[str],
):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
if processor_name not in self.processor_class.attributes:
self.skipTest(f"{processor_name} attribute not present in {self.processor_class}")
batch_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
"content": [{"type": "text", "text": "Describe this."}],
},
]
]
] * batch_size
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1)
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1)
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
def test_image_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "What do you see?"},
],
},
],
]
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 2)
# Test that jinja can be applied
formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), batch_size)
# Test that tokenizing with template and directly with `self.tokenizer` gives same output
formatted_prompt_tokenized = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors
)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None, padding=True).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
tok_output = processor.tokenizer(
formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens
)
expected_output = tok_output.input_ids
self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
# Test that kwargs passed to processor's `__call__` are actually used
tokenized_prompt_100 = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
return_tensors=return_tensors,
max_length=100,
)
self.assertEqual(len(tokenized_prompt_100[0]), 100)
# Test that `return_dict=True` returns text related inputs in the dict
out_dict_text = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
)
self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
# Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
for idx, url in enumerate(input_data[:batch_size]):
batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4, # by default no more than 4 frames, otherwise too slow
)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
self.assertEqual(len(out_dict[input_name]), batch_size * 19200)
# Now test the ability to return dict
batched_messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
batched_messages[1][0]["content"].append(
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
)
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 2)
self.assertEqual(len(out_dict["attention_mask"]), 2)
self.assertEqual(len(out_dict[self.images_input_name]), 90480)
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
for k in out_dict:
self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors])
@require_av
def test_chat_template_video(self):
def test_apply_chat_template_video_frame_sampling(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
@ -312,52 +312,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 71280)
@require_av
def test_chat_template_video_custom_sampling(self):
"""
Tests that models can pass their custom callables to sample video indices.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def dummy_sample_indices_fn(metadata, **fn_kwargs):
# always sample only the first two frames
return [0, 1]
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
sample_indices_fn=dummy_sample_indices_fn,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 14400)
@require_av
def test_chat_template_video_special_processing(self):
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
@ -414,6 +369,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)

View File

@ -162,29 +162,14 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.skipTest("Processor has no chat template")
images = self.prepare_image_inputs(batch_size=2)
print(images)
processed_inputs = processor(images=images)
self.assertEqual(len(processed_inputs[self.text_input_name]), 6)
self.assertEqual(len(processed_inputs[self.images_input_name]), 6)
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_image_chat_template_accepts_processing_kwargs(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_image_chat_template_batched(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_image_chat_template_dict_torch(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_image_chat_template_single(self):
def test_apply_chat_template_image(self, batch_size: int, return_tensors: str):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2

View File

@ -368,12 +368,12 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
self.assertEqual(rendered, expected_rendered)
@unittest.skip(reason="Broken from common. Fixing TODO @zucchini-nlp @molbap")
def test_chat_template_video_special_processing(self):
@unittest.skip(reason="SmolVLM replaced `type=video` with `type=image` in chat templates")
def test_apply_chat_template_video_special_processing(self):
pass
@require_av
def test_chat_template_video(self):
def test_apply_chat_template_video_frame_sampling(self):
# overridden because SmolVLM has special preprocessing for videos
processor = self.get_processor()
if processor.chat_template is None:
@ -401,11 +401,12 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
num_frames=num_frames,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
# SmolVLM doesn't sample `num_frames` exactly, but uses another sampling method
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 10)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 3)
# Load with `video_fps` arg
video_fps = 1
@ -415,6 +416,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
video_fps=video_fps,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)

View File

@ -1,41 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
if is_vision_available():
from transformers import AutoProcessor
@require_vision
class LlavaProcessorTest(unittest.TestCase):
def test_chat_template(self):
processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
expected_prompt = "###Human: <image>\nWhat is shown in this image?###Assistant:"
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)

View File

@ -22,6 +22,7 @@ from typing import Optional
import numpy as np
from huggingface_hub import hf_hub_download
from parameterized import parameterized
from transformers.models.auto.processing_auto import processor_class_from_name
from transformers.processing_utils import Unpack
@ -44,6 +45,22 @@ if is_torch_available():
import torch
MODALITY_INPUT_DATA = {
"images": [
"http://images.cocodataset.org/val2017/000000039769.jpg",
"http://images.cocodataset.org/val2017/000000039769.jpg",
],
"videos": [
"https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
["https://www.ilankelman.org/stopsigns/australia.jpg", "https://www.ilankelman.org/stopsigns/australia.jpg"],
],
"audio": [
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav",
],
}
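These URLs feed the unified per-modality chat-template tests. For orientation, here is a sketch of how one modality-specific test inside `ProcessorTesterMixin` dispatches to the shared `_test_apply_chat_template` helper, patterned on the audio variant visible at the end of this diff; the image-specific attribute-name strings are inferred, not confirmed by this hunk.

```py
# Sketch of a class-body excerpt; `_ExampleMixin` is a stand-in for
# ProcessorTesterMixin, and "images_input_name" / "image_processor" are inferred.
from parameterized import parameterized

class _ExampleMixin:
    @parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
    def test_apply_chat_template_image(self, batch_size: int, return_tensors: str):
        self._test_apply_chat_template(
            "image", batch_size, return_tensors, "images_input_name", "image_processor", MODALITY_INPUT_DATA["images"]
        )
```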
def prepare_image_inputs():
"""This function prepares a list of PIL images"""
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
@ -729,7 +746,7 @@ class ProcessorTesterMixin:
)
def test_chat_template_save_loading(self):
processor = self.get_processor()
processor = self.processor_class.from_pretrained(self.tmpdirname)
signature = inspect.signature(processor.__init__)
if "chat_template" not in {*signature.parameters.keys()}:
self.skipTest("Processor doesn't accept chat templates at input")
@ -756,210 +773,133 @@ class ProcessorTesterMixin:
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
def test_image_chat_template_single(self):
@require_torch
def _test_apply_chat_template(
self,
modality: str,
batch_size: int,
return_tensors: str,
input_name: str,
processor_name: str,
input_data: list[str],
):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
if processor_name not in self.processor_class.attributes:
self.skipTest(f"{processor_name} attribute not present in {self.processor_class}")
messages = [
# some models only have a fast image processor
if getattr(processor, processor_name).__class__.__name__.endswith("Fast"):
return_tensors = "pt"
batch_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
"content": [{"type": "text", "text": "Describe this."}],
},
]
]
] * batch_size
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1)
# Test that jinja can be applied
formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), batch_size)
# Test that tokenizing with template and directly with `self.tokenizer` gives same output
formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors
)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
expected_output = processor.tokenizer(
formatted_prompt, return_tensors=None, add_special_tokens=add_special_tokens
).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
tok_output = processor.tokenizer(
formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.images_input_name in out_dict)
expected_output = tok_output.input_ids
self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1)
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 1)
def test_image_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "What do you see?"},
],
},
],
]
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 2)
formatted_prompt_tokenized = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, padding=True, return_tensors=None
)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
expected_output = processor.tokenizer(
formatted_prompt,
return_tensors=None,
padding=True,
add_special_tokens=add_special_tokens,
).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(
batched_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
padding=True,
)
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
# Now test the ability to return dict
batched_messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
batched_messages[1][0]["content"].append(
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
)
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 2)
self.assertEqual(len(out_dict["attention_mask"]), 2)
self.assertEqual(len(out_dict[self.images_input_name]), 2)
def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
# Test that kwargs passed to processor's `__call__` are actually used
tokenized_prompt_100 = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
truncation=True,
max_length=100,
return_tensors=return_tensors,
)
self.assertEqual(len(tokenized_prompt_100[0]), 100)
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
# Test that `return_dict=True` returns text related inputs in the dict
out_dict_text = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
)
self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
do_rescale=True,
rescale_factor=-1,
return_tensors="np",
)
# a negative rescale factor can only come from the kwargs above, so a non-positive pixel mean
# shows that image-processor kwargs were actually forwarded
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
# Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
for idx, url in enumerate(input_data[:batch_size]):
batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]
out_dict = processor.apply_chat_template(
batch_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4,  # by default no more than 4 frames, otherwise too slow
)
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
self.assertEqual(len(out_dict[input_name]), batch_size)
# outputs returned with `return_dict=True` should match the requested tensor type
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
for k in out_dict:
self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors])
@require_torch
def test_image_chat_template_dict_torch(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.images_input_name in out_dict_tensors)
for k in out_dict_tensors:
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)
# Test continue from final message
assistant_message = {
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of"}],
}
for idx, url in enumerate(input_data[:batch_size]):
batch_messages[idx] = batch_messages[idx] + [assistant_message]
continue_prompt = processor.apply_chat_template(batch_messages, continue_final_message=True, tokenize=False)
for prompt in continue_prompt:
self.assertTrue(prompt.endswith("It is the sound of")) # no `eos` token at the end
@require_librosa
@parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
def test_apply_chat_template_audio(self, batch_size: int, return_tensors: str):
self._test_apply_chat_template(
"audio", batch_size, return_tensors, "audio_input_name", "feature_extractor", MODALITY_INPUT_DATA["audio"]
)
@require_av
@parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
def test_apply_chat_template_video(self, batch_size: int, return_tensors: str):
self._test_apply_chat_template(
"video", batch_size, return_tensors, "videos_input_name", "video_processor", MODALITY_INPUT_DATA["videos"]
)
@parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
def test_apply_chat_template_image(self, batch_size: int, return_tensors: str):
self._test_apply_chat_template(
"image", batch_size, return_tensors, "images_input_name", "image_processor", MODALITY_INPUT_DATA["images"]
)
def test_apply_chat_template_video_frame_sampling(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
@ -975,37 +915,16 @@ class ProcessorTesterMixin:
{
"role": "user",
"content": [
{"type": "video"},
{
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1)
formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
)
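# avoid a duplicated BOS token when the chat template already adds it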
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
expected_output = processor.tokenizer(
formatted_prompt,
return_tensors=None,
add_special_tokens=add_special_tokens,
).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(all(key in out_dict for key in ["input_ids", "attention_mask"]))
# Add video URL for return dict and load with `num_frames` arg
messages[0][0]["content"][0] = {
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
}
num_frames = 3
out_dict_with_video = processor.apply_chat_template(
messages,
@ -1013,6 +932,7 @@ class ProcessorTesterMixin:
tokenize=True,
return_dict=True,
num_frames=num_frames,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@ -1026,6 +946,7 @@ class ProcessorTesterMixin:
tokenize=True,
return_dict=True,
video_fps=video_fps,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@ -1073,53 +994,7 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
@require_av
def test_chat_template_video_custom_sampling(self):
"""
Tests that models can pass their custom callables to sample video indices.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def dummy_sample_indices_fn(metadata, **fn_kwargs):
# always sample only the first two frames
return [0, 1]
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
sample_indices_fn=dummy_sample_indices_fn,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
@require_av
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
@ -1176,6 +1051,7 @@ class ProcessorTesterMixin:
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="np",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
@ -1187,7 +1063,7 @@ class ProcessorTesterMixin:
@require_librosa
@require_av
def test_chat_template_audio_from_video(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
@ -1241,124 +1117,10 @@ class ProcessorTesterMixin:
load_audio_from_video=True,
)
self.assertTrue(self.audio_input_name in out_dict)
self.assertTrue(self.videos_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
self.assertEqual(len(out_dict[self.videos_input_name]), 1)  # 1 video in the conversation
@require_librosa
def test_audio_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
},
{"type": "text", "text": "How about this one?"},
],
},
]
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1) # batch size=1
formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
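# attach the actual audio URLs so the `return_dict=True` call below can load and featurize them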
messages[1]["content"][0]["audio"] = (
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
)
messages[3]["content"][0]["audio"] = (
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.audio_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
@require_torch
@require_librosa
def test_audio_chat_template_dict_torch(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav",
},
{"type": "text", "text": "How about this one?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.audio_input_name in out_dict_tensors)
for k in out_dict_tensors:
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)