[video processors] support frame sampling within processors (#38105)

* apply updates to smolVLM (still needs workaround for chat template)

* add other models

* dump qwen omni for now, come back later

* port qwen omni from their impl

* wait, all qwens sample videos in the same way!

* clean up

* make smolvlm backwards compatible and fix padding

* fix some tests

* fix smolvlm tests

* more clean up and test fixing

* delete unused arg

* fix

* address comments

* style

* fix test
Raushan Turganbay 2025-06-12 11:34:30 +02:00 committed by GitHub
parent 887054c714
commit 27459025b8
25 changed files with 864 additions and 795 deletions
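At a glance, this change adds a `do_sample_frames` switch (off by default for backward compatibility) plus `num_frames`/`fps` and per-video metadata kwargs to the fast video processors. A minimal usage sketch, assuming `AutoVideoProcessor` is available and that the kwargs are forwarded to `_preprocess` as in the diffs below; the checkpoint id is a placeholder:

```python
import numpy as np
from transformers import AutoVideoProcessor

video = np.random.randint(0, 255, size=(64, 360, 640, 3), dtype=np.uint8)  # 64 decoded frames
metadata = {"total_num_frames": 64, "fps": 24, "duration": 64 / 24}        # VideoMetadata-like dict

processor = AutoVideoProcessor.from_pretrained("path/to/a-video-checkpoint")  # placeholder id
out = processor(
    videos=[video],
    video_metadata=[[metadata]],  # nested: one metadata list per batch entry
    do_sample_frames=True,        # opt in; defaults to False for BC
    num_frames=8,                 # or `fps=...`; the two are mutually exclusive
    return_tensors="pt",
)
```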

View File

@ -35,7 +35,7 @@ from ...utils import (
)
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
from ...video_utils import group_videos_by_shape, reorder_videos
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
if is_vision_available():
@ -66,6 +66,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = InstructBlipVideoVideoProcessorInitKwargs
model_input_names = ["pixel_values"]
@ -75,6 +76,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
@ -86,10 +88,18 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
do_pad: bool,
rescale_factor: float,
do_normalize: bool,
do_sample_frames: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
fps: Optional[int] = None,
num_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if do_sample_frames:
videos = [
self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
]
# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
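The hook added to `_preprocess` keeps the old behaviour when sampling is disabled and defers to `sample_frames` otherwise. A standalone sketch of that contract (helper and argument names are mine, not from the diff):

```python
from typing import Optional

def maybe_sample(videos, video_metadata, sampler, do_sample_frames: bool,
                 num_frames: Optional[int] = None, fps: Optional[int] = None):
    """Mirror of the sampling hook added to `_preprocess` (sketch only)."""
    if not do_sample_frames:
        # BC path: frames are assumed to have been sampled upstream, e.g. at load time
        return videos
    return [sampler(video, meta, num_frames, fps) for video, meta in zip(videos, video_metadata)]
```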

View File

@ -21,7 +21,7 @@ from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
from ...video_utils import VideoInput, make_batched_videos
class InternVLImagesKwargs(ImagesKwargs, total=False):
@ -290,32 +290,6 @@ class InternVLProcessor(ProcessorMixin):
return MultiModalData(**vision_data)
def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
):
"""
The function to generate indices of frames to sample from a video.
Args:
metadata (`VideoMetadata`):
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If None, all frames are sampled.
initial_shift (`bool`, `float` or `int`, defaults to `0`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
Returns:
`np.ndarray`: Array of frame indices to sample.
"""
num_frames = num_frames if num_frames is not None else metadata.total_num_frames
if initial_shift is True:
initial_shift = metadata.total_num_frames / num_frames / 2
indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
int
)
return indices
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
@ -336,39 +310,5 @@ class InternVLProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(tokenizer_input_names) + list(image_processor_input_names)
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def _load_video_for_model(
self,
video: Union[str, "VideoInput"],
num_frames: Optional[int],
backend: str = "pyav",
initial_shift: bool = True,
**kwargs,
) -> np.array:
"""
Loads `video` to a numpy array.
Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
backend (`str`, *optional*, defaults to `"pyav"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
initial_shift (`bool`, *optional*, defaults to `True`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
def sample_indices_fn_func(metadata, **fn_kwargs):
return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata
__all__ = ["InternVLProcessor"]

View File

@ -14,25 +14,43 @@
# limitations under the License.
"""Fast Video processor class for InternVL."""
from typing import List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
SizeDict,
)
from ...processing_utils import Unpack, VideosKwargs
from ...utils import (
TensorType,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
)
from ...utils.import_utils import requires
from ...video_processing_utils import (
BaseVideoProcessor,
)
from ...video_processing_utils import BaseVideoProcessor
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
if is_torch_available():
import torch
if is_vision_available():
from ...image_utils import PILImageResampling
class InternVLVideoProcessorInitKwargs(VideosKwargs): ...
class InternVLVideoProcessorInitKwargs(VideosKwargs):
initial_shift: Union[bool, float, int]
@requires(backends=("torchvision",))
@ -45,11 +63,128 @@ class InternVLVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
initial_shift = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = InternVLVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[InternVLVideoProcessorInitKwargs]):
super().__init__(**kwargs)
def sample_frames(
self,
video: "torch.Tensor",
metadata: Optional[Union[VideoMetadata, dict]] = None,
num_frames: Optional[int] = None,
fps: Optional[int] = None,
initial_shift: Optional[Union[bool, float, int]] = None,
):
"""
Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
and `fps` are mutually exclusive.
Args:
video (`torch.Tensor`):
Video that needs to be sampled.
metadata (`VideoMetadata`, *optional*):
Metadata of the video containing information about total duration, fps and total number of frames.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
fps (`int`, *optional*):
Target frames to sample per second. Defaults to `self.fps`.
initial_shift (`bool`, `float` or `int`, defaults to `self.initial_shift`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
Returns:
torch.Tensor:
Sampled video frames.
"""
num_frames = num_frames if num_frames is not None else self.num_frames
initial_shift = initial_shift if initial_shift is not None else self.initial_shift
total_num_frames = video.shape[0]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is None and fps is not None:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
)
num_frames = int(total_num_frames / metadata["fps"] * fps)
if initial_shift is True:
initial_shift = total_num_frames / num_frames / 2
if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
)
indices = torch.arange(initial_shift, total_num_frames, total_num_frames / num_frames).int()
video = video[indices].contiguous()
return video
def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
size_divisor: Optional[int],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
do_pad: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_sample_frames: Optional[bool] = None,
fps: Optional[int] = None,
num_frames: Optional[int] = None,
initial_shift: Optional[Union[bool, float, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if do_sample_frames:
# Sample video frames
videos = [
self.sample_frames(video, metadata, fps=fps, num_frames=num_frames, initial_shift=initial_shift)
for video, metadata in zip(videos, video_metadata)
]
# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
if do_convert_rgb:
stacked_videos = self.convert_to_rgb(stacked_videos)
if do_resize:
stacked_videos = self.resize(
stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation
)
resized_videos_grouped[shape] = stacked_videos
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
# Group videos by size for further processing
# Needed in case do_resize is False, or resize returns videos with different sizes
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
processed_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
if do_center_crop:
stacked_videos = self.center_crop(stacked_videos, crop_size)
# Fused rescale and normalize
stacked_videos = self.rescale_and_normalize(
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_videos_grouped[shape] = stacked_videos
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)
__all__ = ["InternVLVideoProcessor"]
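A worked example of the centered sampling above: with `initial_shift=True` the shift becomes half a sampling bin, so every index lands in the middle of its bin (numbers are illustrative):

```python
import torch

total_num_frames, num_frames = 240, 8
initial_shift = total_num_frames / num_frames / 2  # 15.0, half of a 30-frame bin
indices = torch.arange(initial_shift, total_num_frames, total_num_frames / num_frames).int()
print(indices.tolist())  # [15, 45, 75, 105, 135, 165, 195, 225]
```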

View File

@ -46,6 +46,7 @@ class LlavaNextVideoVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = LlavaNextVideoFastVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]

View File

@ -47,6 +47,7 @@ class LlavaOnevisionVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = LlavaOnevisionFastVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]

View File

@ -154,7 +154,7 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk")
position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds")
use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video")
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
if audio is not None:
output_kwargs["audio_kwargs"]["padding"] = "max_length" # Support "max_length" padding only here

View File

@ -928,7 +928,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}
@ -1013,9 +1012,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
# pop fps in advance for passing kwargs validation
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]

View File

@ -54,7 +54,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}
@ -151,9 +150,7 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
# pop fps in advance for passing kwargs validation
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
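The `pop` → `get` change in the Qwen processors is what lets `fps` reach the video processor for sampling: `get` reads the value without removing it from `videos_kwargs`, so the processor can still use it for its own bookkeeping while the same kwarg is forwarded downstream. A trivial illustration:

```python
videos_kwargs = {"fps": 2.0, "do_sample_frames": True}
fps = videos_kwargs.get("fps", 2.0)  # read without removing
print(fps, videos_kwargs)            # 2.0 {'fps': 2.0, 'do_sample_frames': True}
```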

View File

@ -19,6 +19,7 @@
# limitations under the License.
"""video processor class for Qwen2-VL."""
import math
from typing import List, Optional, Union
from ...image_processing_utils import (
@ -45,7 +46,7 @@ from ...video_processing_utils import (
BASE_VIDEO_PROCESSOR_DOCSTRING,
BaseVideoProcessor,
)
from ...video_utils import group_videos_by_shape, reorder_videos
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
if is_vision_available():
@ -69,6 +70,8 @@ class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
patch_size: Optional[int]
temporal_patch_size: Optional[int]
merge_size: Optional[int]
min_frames: Optional[int]
max_frames: Optional[int]
@add_start_docstrings(
@ -85,23 +88,30 @@ class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to 2):
The merge size of the vision encoder to llm encoder.
min_frames (`int`, *optional*, defaults to 4):
The minimum number of frames that can be sampled.
max_frames (`int`, *optional*, defaults to 768):
The maximum number of frames that can be sampled.
""",
)
@requires(backends=("torchvision",))
class Qwen2VLVideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BICUBIC
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
size = {"shortest_edge": 128 * 28 * 28, "longest_edge": 28 * 28 * 768}
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
min_pixels = 56 * 56
max_pixels = 28 * 28 * 1280
min_pixels = 128 * 28 * 28
max_pixels = 28 * 28 * 768
patch_size = 14
temporal_patch_size = 2
merge_size = 2
min_frames = 4
max_frames = 768
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = Qwen2VLVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos", "video_grid_thw"]
@ -109,9 +119,80 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
super().__init__(**kwargs)
self.size = {"shortest_edge": self.min_pixels, "longest_edge": self.max_pixels}
def sample_frames(
self,
video: "torch.Tensor",
frame_factor: int,
min_frames: int,
max_frames: int,
metadata: Optional[Union[VideoMetadata, dict]] = None,
num_frames: Optional[int] = None,
fps: Optional[int] = None,
):
"""
Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
and `fps` are mutually exclusive.
Args:
video (`torch.Tensor`):
Video that needs to be sampled.
frame_factor (`int`):
The temporal patch size of the vision encoder. Number of sampled frames will be rounded to be divisible by frame factor.
min_frames (`int`):
The minimum number of frames that can be sampled.
max_frames (`int`):
The maximum number of frames that can be sampled.
metadata (`VideoMetadata`, *optional*):
Metadata of the video containing information about total duration, fps and total number of frames.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
fps (`int`, *optional*):
Target frames to sample per second. Defaults to `self.fps`.
Returns:
torch.Tensor:
Sampled video frames.
"""
if fps is not None and num_frames is not None:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
num_frames = num_frames if num_frames is not None else self.num_frames
fps = fps if fps is not None else self.fps
total_num_frames = video.shape[0]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is not None:
num_frames = round(num_frames / frame_factor) * frame_factor
elif fps is not None:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
)
max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor
num_frames = total_num_frames / metadata["fps"] * fps
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
num_frames = math.floor(num_frames / frame_factor) * frame_factor
if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
"Decrease `num_frames` or `fps` for sampling."
)
if num_frames is not None:
indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
else:
indices = torch.arange(0, total_num_frames).int()
video = video[indices].contiguous()
return video
def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
@ -119,6 +200,7 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
do_sample_frames: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
min_pixels: Optional[int] = None,
@ -126,9 +208,28 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
patch_size: Optional[int] = None,
temporal_patch_size: Optional[int] = None,
merge_size: Optional[int] = None,
fps: Optional[int] = None,
num_frames: Optional[int] = None,
min_frames: Optional[int] = None,
max_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
if do_sample_frames:
# Sample video frames
videos = [
self.sample_frames(
video,
frame_factor=temporal_patch_size,
min_frames=min_frames,
max_frames=max_frames,
metadata=metadata,
num_frames=num_frames,
fps=fps,
)
for video, metadata in zip(videos, video_metadata)
]
# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
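A worked example of the Qwen2-VL frame-count resolution above (numbers are illustrative): a 10-second clip decoded at 30 fps, sampled at 2 fps with `frame_factor=2`:

```python
import math

total_num_frames, metadata_fps = 300, 30
fps, frame_factor, min_frames, max_frames = 2, 2, 4, 768

max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor  # 300
num_frames = total_num_frames / metadata_fps * fps                                        # 20.0
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)          # clamp -> 20.0
num_frames = math.floor(num_frames / frame_factor) * frame_factor                         # round -> 20
print(num_frames)  # 20
```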

View File

@ -16,18 +16,15 @@
Processor class for SmolVLM.
"""
import copy
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import AllKwargsForChatTemplate, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, TextInput
from ...utils import is_num2words_available, is_vision_available, logging
from ...video_utils import VideoInput, load_video, make_batched_videos
from ...video_utils import VideoInput
if is_vision_available():
@ -35,7 +32,13 @@ if is_vision_available():
DEFAULT_MEDIA_OUTTRO,
DEFAULT_VIDEO_INTRO,
FRAME_TIMESTAMP_MESSAGE,
smolvlm_sample_indices_fn,
)
if is_vision_available():
from .video_processing_smolvlm import (
DEFAULT_MEDIA_OUTTRO,
DEFAULT_VIDEO_INTRO,
FRAME_TIMESTAMP_MESSAGE,
)
if TYPE_CHECKING:
@ -50,6 +53,10 @@ else:
num2words = None
# The correct chat template to be used for videos after #38105
DEFAULT_CHAT_TEMPLATE = "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% elif line['type'] == 'video' %}{{ '<video>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
def _prompt_split_image(
image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_image_token
):
@ -140,9 +147,7 @@ class SmolVLMProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer", "video_processor"]
image_processor_class = "SmolVLMImageProcessor"
video_processor_class = (
"SmolVLMImageProcessor" # TODO: raushan should be VideoProcessor when LANCZOS resizing is settled
)
video_processor_class = "SmolVLMVideoProcessor" # NOTE: uses different interpolation than slow processors
tokenizer_class = "AutoTokenizer"
def __init__(
@ -160,17 +165,7 @@ class SmolVLMProcessor(ProcessorMixin):
self.end_of_utterance_token = getattr(tokenizer, "end_of_utterance_token", "<end_of_utterance>")
self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
self.image_seq_len = image_seq_len
self.video_size = video_processor.video_sampling["video_size"]
self.image_size = image_processor.size
self.do_image_splitting = image_processor.do_image_splitting
self.do_video_splitting = video_processor.video_sampling.get("do_image_splitting", False)
self.default_max_frames = video_processor.video_sampling["max_frames"]
self.default_fps = video_processor.video_sampling["fps"]
# Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits, optionally surrounded by newline characters
# self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
self.video_token = getattr(tokenizer, "video_token", "<video>")
if not num2words:
raise ImportError(
@ -179,16 +174,12 @@ class SmolVLMProcessor(ProcessorMixin):
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
def process_vision(
self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None, processor=None
):
def process_vision(self, text, images, output_kwargs):
if text is not None:
n_images_in_text = [sample.count(self.image_token) for sample in text]
n_images_in_images = [len(sublist) for sublist in images]
image_inputs = processor(
images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs
)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is None:
return None, image_inputs
@ -227,6 +218,50 @@ class SmolVLMProcessor(ProcessorMixin):
return prompt_strings, image_inputs
def process_video(self, text, videos, output_kwargs):
if text is not None:
n_videos_in_text = [sample.count(self.video_token) for sample in text]
n_videos_in_videos = [len(sublist) for sublist in videos]
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
num_frames = video_inputs["pixel_values"].shape[1]
batch_timestamps = iter(video_inputs.pop("timestamps"))
batch_durations = iter(video_inputs.pop("durations"))
if text is None:
return None, video_inputs
if n_videos_in_videos != n_videos_in_text:
raise ValueError(
f"The number of videos in the text {n_videos_in_text} and videos {n_videos_in_videos} should be the same."
)
prompt_strings = []
for sample in text:
while self.video_token in sample:
timestamps = next(batch_timestamps)
duration = next(batch_durations)
duration_td = timedelta(seconds=int(duration))
image_prompt_strings = DEFAULT_VIDEO_INTRO.format(
frame_count=num2words(num_frames), video_duration=str(duration_td)
)
for timestamp in timestamps:
image_prompt_string = _prompt_single_image(
self.image_seq_len,
image_token=self.image_token,
fake_token_around_image=self.fake_image_token,
global_image_token=self.global_image_token,
)
timestamp = f"{timestamp[0]:02d}:{timestamp[1]:02d}"
image_prompt_string = FRAME_TIMESTAMP_MESSAGE.format(timestamp=timestamp) + image_prompt_string
image_prompt_strings += image_prompt_string
image_prompt_strings += DEFAULT_MEDIA_OUTTRO
sample = sample.replace(self.video_token, image_prompt_strings, 1)
prompt_strings.append(sample)
return prompt_strings, video_inputs
def __call__(
self,
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
@ -310,21 +345,14 @@ class SmolVLMProcessor(ProcessorMixin):
text, vision_inputs = self.process_vision(
text,
images,
output_kwargs["images_kwargs"],
do_image_splitting=self.do_image_splitting,
image_processor_size=self.image_size,
processor=self.image_processor,
output_kwargs,
)
inputs.update(vision_inputs)
elif videos is not None:
videos = make_batched_videos(videos)
text, vision_inputs = self.process_vision(
text, vision_inputs = self.process_video(
text,
videos,
output_kwargs["videos_kwargs"],
do_image_splitting=self.do_image_splitting,
image_processor_size=self.video_size,
processor=self.video_processor,
output_kwargs,
)
inputs.update(vision_inputs)
@ -337,93 +365,6 @@ class SmolVLMProcessor(ProcessorMixin):
return BatchFeature(inputs, tensor_type=return_tensors)
def _process_messages_for_chat_template(
self,
conversations: List[List[Dict[str, str]]],
batch_images: List[ImageInput],
batch_videos: List[VideoInput],
batch_video_metadata: List[List[Dict[str, any]]],
**chat_template_kwargs,
):
"""
Used within `apply_chat_template` when a model has special way to process conversation history. For example,
video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
were sampled. This information cannot be accessed before the video is loaded.
For most models it is a no-op, must be overridden by model processors which require special processing.
Args:
conversation (`List[Dict, str, str]`):
The conversation to process. Always comes in batched format.
batch_images (`List[List[ImageInput]]`):
Batch of images that were loaded from url/path defined in the conversation. The images
are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
per batch.
batch_videos (`List[List[ImageInput]]`):
Batch of videos that were loaded from url/path defined in the conversation. The videos
are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
per batch.
batch_video_metadata (`List[List[Dict[[str, any]]]]`):
Batch of metadata returned from loading videos. That includes video fps, duration and total number of frames in the original video.
Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
per batch.
"""
# We don't want to modify in-place the messages passed by user
# The user might want to add new turn on conv and continue generation
conversations = copy.deepcopy(conversations)
batch_num_frames, batch_timestamps = [], []
for metadata_list, video_list in zip(batch_video_metadata, batch_videos):
for metadata, video in zip(metadata_list, video_list):
duration_sec = getattr(metadata, "duration")
frames_idx = getattr(metadata, "frames_indices")
fps = getattr(metadata, "fps")
timestamps = []
for idx, frame_np in zip(frames_idx, video):
sec = idx / fps
mm = int(sec // 60)
ss = int(sec % 60)
timestamps.append(f"{mm:02d}:{ss:02d}")
batch_timestamps.append(timestamps)
batch_num_frames.append(len(video))
for conversation in conversations:
# For each message, scan content for {"type": "video"}
for msg in conversation:
if "content" not in msg:
continue
new_content = []
for block in msg["content"]:
if block.get("type") == "video":
curr_timestamps = batch_timestamps.pop(0)
curr_num_frames = batch_num_frames.pop(0)
# Build the video intro texts
td = timedelta(seconds=int(duration_sec))
new_content.append(
{
"type": "text",
"text": DEFAULT_VIDEO_INTRO.format(
frame_count=num2words(curr_num_frames), video_duration=str(td)
),
}
)
# 2) Insert per-frame lines: "Frame from {timestamp}:", then an "image" block
for i, ts in enumerate(curr_timestamps):
new_content.append({"type": "text", "text": FRAME_TIMESTAMP_MESSAGE.format(timestamp=ts)})
new_content.append({"type": "image"})
# 3) Optionally add an outro (e.g. "Now answer the question:")
new_content.append({"type": "text", "text": DEFAULT_MEDIA_OUTTRO})
# Do NOT add the original block => we skip it (since we've replaced it)
else:
# keep original block
new_content.append(block)
# update the content
msg["content"] = new_content
return conversations
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
@ -446,45 +387,54 @@ class SmolVLMProcessor(ProcessorMixin):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def _load_video_for_model(
def apply_chat_template(
self,
video: Union[str, "VideoInput"],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
skip_secs: int = 0.0,
**kwargs,
) -> np.array:
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
chat_template: Optional[str] = None,
**kwargs: Unpack[AllKwargsForChatTemplate],
) -> str:
"""
Loads `video` to a numpy array.
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
conversations to turn them into a single tokenizable string.
The input is expected to be in the following format, where each message content is a list consisting of text and
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
conversation = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Please describe this image in detail."},
],
},
]
Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
conversation (`Union[List[Dict[str, str]], List[List[Dict[str, str]]]]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
chat template is used.
"""
max_frames = self.default_max_frames if num_frames is None else num_frames
target_fps = self.default_fps if fps is None else fps
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
):
conversations = conversation
else:
conversations = [conversation]
def sample_indices_fn_func(metadata, **fn_kwargs):
return smolvlm_sample_indices_fn(
metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
)
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata
has_video = any(
(isinstance(content, dict) and content["type"] == "video")
for conversation in conversations
for message in conversation
for content in message["content"]
)
if chat_template is None and has_video:
# re-assign to the correct default template for BC, if user is not requesting their own template
chat_template = DEFAULT_CHAT_TEMPLATE
return super().apply_chat_template(conversation, chat_template, **kwargs)
__all__ = ["SmolVLMProcessor"]
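The `apply_chat_template` override above swaps in the video-aware default template only when the conversation actually contains a video block. The detection reduces to the check below (the conversation content is a made-up example; the local path is hypothetical):

```python
conversation = [
    {"role": "user", "content": [
        {"type": "video", "path": "clip.mp4"},
        {"type": "text", "text": "Describe this video."},
    ]},
]
conversations = [conversation]

has_video = any(
    isinstance(content, dict) and content["type"] == "video"
    for conv in conversations
    for message in conv
    for content in message["content"]
)
print(has_video)  # True -> DEFAULT_CHAT_TEMPLATE is used unless the caller passes their own
```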

View File

@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
from ...image_processing_utils import (
BatchFeature,
get_size_dict,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
@ -38,7 +38,7 @@ from ...utils.import_utils import requires
from ...video_processing_utils import (
BaseVideoProcessor,
)
from ...video_utils import group_videos_by_shape, reorder_videos
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
if is_vision_available():
@ -68,66 +68,6 @@ FRAME_TIMESTAMP_MESSAGE = "\nFrame from {timestamp}:"
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
def smolvlm_sample_indices_fn(metadata, max_frames, target_fps, skip_secs=0):
"""
Example sampling function which:
- Uses `max_frames` (if provided) or calculates it from `fps` and metadata.
- Applies a basic center-skip if fewer frames than available, otherwise
optionally skips `skip_secs` from both the start and end.
- Uniformly samples the desired number of frames between the start and end indices.
Args:
max_frames (`int`):
Maximum number of frames to sample.
target_fps (`int`):
Target frames to sample per second.
metadata (`dict`):
Contains video metadata such as "n_frames" and "video_fps".
skip_secs (`float`, *optional*, defaults to 1.0):
Number of seconds to skip from the start and end if the video is long enough.
Returns:
numpy.ndarray:
An array of unique frame indices to sample.
"""
total_num_frames = getattr(metadata, "total_num_frames", 0)
if total_num_frames <= 0:
raise ValueError(f"Invalid total_num_frames={total_num_frames} in metadata.")
native_fps = getattr(metadata, "fps", 30.0)
duration_seconds = getattr(metadata, "duration", 0)
if duration_seconds <= 0:
raise ValueError(f"Invalid duration_seconds={duration_seconds} in metadata.")
# Step 1) Estimate how many frames we'd sample at `target_fps`, fallback if target_fps <= 0
estimated_frames = int(round(target_fps * duration_seconds))
# Step 2) desired_frames
desired_frames = min(estimated_frames, max_frames)
if desired_frames < 1:
desired_frames = 1
# Step 3) center skip logic
start_idx = 0
end_idx = total_num_frames - 1
if skip_secs > 0 and (duration_seconds - 2 * skip_secs) > (max_frames * target_fps):
start_idx = int(skip_secs * native_fps)
end_idx = int(total_num_frames - skip_secs * native_fps)
start_idx = max(0, start_idx)
end_idx = min(end_idx, total_num_frames - 1)
if start_idx >= end_idx:
start_idx, end_idx = 0, total_num_frames - 1
indices = np.linspace(start_idx, end_idx, desired_frames, dtype=int)
indices = np.unique(indices)
return indices
def get_max_height_width(videos: list["torch.Tensor"]) -> List[int]:
"""
Get the maximum height and width across all videos in a batch.
@ -180,13 +120,15 @@ def get_resize_output_image_size(
return height, width
class SmolVLMVideoProcessorInitKwargs(VideosKwargs): ...
class SmolVLMVideoProcessorInitKwargs(VideosKwargs):
max_image_size: dict[str, int] = None
@requires(backends=("torchvision",))
class SmolVLMVideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.LANCZOS
size = {"longest_edge": 4 * 364}
max_image_size = {"longest_edge": 364}
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
@ -194,11 +136,21 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
do_normalize = True
do_convert_rgb = True
do_pad = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = SmolVLMVideoProcessorInitKwargs
model_input_names = ["pixel_values", "pixel_attention_mask"]
def __init__(self, **kwargs: Unpack[SmolVLMVideoProcessorInitKwargs]):
super().__init__(**kwargs)
# For BC pop values from `config.video_sampling`. In official config `video_sampling` is guaranteed to be present
# We check for `Noneness` only for certain tests such as `test_init_without_params`
if "size" in kwargs and "video_sampling" in kwargs:
kwargs["video_sampling"]["video_size"] = kwargs["size"]
if "video_sampling" in kwargs:
self.num_frames = kwargs["video_sampling"]["max_frames"]
self.fps = kwargs["video_sampling"]["fps"]
self.size = get_size_dict(kwargs["video_sampling"]["video_size"], default_to_square=self.default_to_square)
def resize(
self,
@ -240,12 +192,20 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
new_size = (size.height, size.width)
else:
raise ValueError(f"Size must contain 'height' and 'width' keys, or 'longest_edge' key. Got {size}.")
return F.resize(video, new_size, interpolation=interpolation, antialias=antialias)
video = F.resize(video, new_size, interpolation=interpolation, antialias=antialias)
# Resize again to match image processor when `do_image_splitting=False`. Frames have to be squared to `max_image_size`
# NOTE: videos are always processed without image splitting
max_size = self.max_image_size["longest_edge"], self.max_image_size["longest_edge"]
video = F.resize(video, max_size, interpolation=interpolation, antialias=antialias)
return video
def pad(
self,
video: "torch.Tensor",
padded_size: tuple[int, int],
max_num_frames: int,
fill: int = 0,
return_pixel_mask: bool = True,
):
@ -255,24 +215,28 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
Video to pad.
padded_size (`Tuple[int, int]`):
Height and width to pad.
max_num_frames (`int`):
The maximum number of frames to which video will be padded.
fill (`int`, *optional*):
The value to use for the padding.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
"""
original_size = video.size()[-2:]
padding_bottom = padded_size[0] - original_size[0]
padding_right = padded_size[1] - original_size[1]
if padding_bottom < 0 or padding_right < 0:
padding_height = padded_size[0] - original_size[0]
padding_width = padded_size[1] - original_size[1]
padding_frame = max_num_frames - video.shape[0]
if padding_width < 0 or padding_height < 0:
raise ValueError(
f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
f"original size. Got padded size: {padded_size}, original size: {original_size}."
)
if original_size != padded_size:
padding = [0, 0, padding_right, padding_bottom]
padding = [0, padding_width, 0, padding_height, 0, 0, 0, padding_frame]
video = F.pad(video, padding, fill=fill)
# Make a pixel mask for the video, where 1 indicates a valid pixel and 0 indicates padding.
# Mask shape is (num_frames, height, width) so we omit the channel dim
pixel_mask = None
if return_pixel_mask:
pixel_mask = torch.zeros_like(video[..., 0, :, :], dtype=torch.int64)
@ -280,9 +244,79 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
return video, pixel_mask
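Padding now covers the temporal axis as well as height and width. A self-contained sketch of the idea with `torch.nn.functional.pad` (the diff itself goes through torchvision's `F`; a `(T, C, H, W)` layout is assumed):

```python
import torch
import torch.nn.functional as F

def pad_video(video: torch.Tensor, padded_size, max_num_frames: int, fill: int = 0):
    """Pad a (T, C, H, W) video to `padded_size` spatially and `max_num_frames` temporally."""
    t, _, h, w = video.shape
    pad_h, pad_w, pad_t = padded_size[0] - h, padded_size[1] - w, max_num_frames - t
    # torch.nn.functional.pad pads from the last dim backwards: (W, H, C, T) pairs
    padded = F.pad(video, [0, pad_w, 0, pad_h, 0, 0, 0, pad_t], value=fill)
    mask = torch.zeros(max_num_frames, padded_size[0], padded_size[1], dtype=torch.int64)
    mask[:t, :h, :w] = 1  # 1 marks valid pixels, 0 marks padding
    return padded, mask

video = torch.rand(8, 3, 364, 300)
padded, mask = pad_video(video, padded_size=(364, 364), max_num_frames=10)
print(padded.shape, mask.shape)  # torch.Size([10, 3, 364, 364]) torch.Size([10, 364, 364])
```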
def sample_frames(
self,
video: "torch.Tensor",
metadata: Union[VideoMetadata, dict],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
skip_secs: Optional[int] = 1,
):
"""
Video sampling function which:
- Uses `num_frames` (if provided) or calculates it from `fps` and metadata.
- Applies a basic center-skip if fewer frames than available, otherwise
optionally skips `skip_secs` from both the start and end.
- Uniformly samples the desired number of frames between the start and end indices.
Args:
video (`torch.Tensor`):
Video that needs to be sampled.
metadata (`VideoMetadata`):
Metadata of the video containing information about total duration, fps and total number of frames.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
fps (`int`, *optional*):
Target frames to sample per second. Defaults to `self.fps`.
skip_secs (`float`, *optional*, defaults to `1`):
Number of seconds to skip from the start and end if the video is long enough.
Returns:
torch.Tensor:
Sampled video frames.
"""
num_frames = num_frames if num_frames is not None else self.num_frames
fps = fps if fps is not None else self.fps
total_num_frames = video.shape[0]
# Step 1) Estimate how many frames we'd sample at `target_fps`, fallback if target_fps <= 0
estimated_frames = int(round(fps * metadata["duration"]))
# Step 2) desired_frames
desired_frames = min(estimated_frames, num_frames)
if desired_frames < 1:
desired_frames = 1
# Step 3) center skip logic
start_idx = 0
end_idx = total_num_frames - 1
if skip_secs > 0 and (metadata["duration"] - 2 * skip_secs) > (num_frames * fps):
start_idx = int(skip_secs * metadata["fps"])
end_idx = int(total_num_frames - skip_secs * metadata["fps"])
start_idx = max(0, start_idx)
end_idx = min(end_idx, total_num_frames - 1)
if start_idx >= end_idx:
start_idx, end_idx = 0, total_num_frames - 1
indices = np.linspace(start_idx, end_idx, desired_frames, dtype=int)
indices = np.unique(indices)
video = video[indices].contiguous()
timestamps = []
for idx in indices:
sec = idx / metadata["fps"]
mm = int(sec // 60)
ss = int(sec % 60)
timestamps.append([mm, ss])
return video, timestamps, int(metadata["duration"])
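The returned timestamps are what later feed the "Frame from MM:SS:" lines in the prompt. A worked example of the conversion (numbers are illustrative):

```python
fps = 30.0
indices = [0, 90, 180, 270]  # sampled frame indices
timestamps = [(int(idx / fps // 60), int(idx / fps % 60)) for idx in indices]
print(timestamps)  # [(0, 0), (0, 3), (0, 6), (0, 9)]
```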
def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
@ -291,13 +325,38 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
rescale_factor: float,
do_normalize: bool,
do_pad: bool,
do_sample_frames: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
fps: Optional[int] = None,
num_frames: Optional[int] = None,
skip_secs: Optional[int] = 0,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
):
# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
if do_sample_frames:
if video_metadata[0] is None:
raise ValueError(
"Frame sampling is enabled but no video metadata was found. SmolVLM requires metadata to correctly sample frames. "
"Please pass in a `VideoMetadata` object for each input video or set `do_sample_frames=False`"
)
processed_videos = []
timestamps_list, durations_list = [], []
for video, metadata in zip(videos, video_metadata):
video, timestamps, duration = self.sample_frames(video, metadata, num_frames, fps, skip_secs)
timestamps_list.append(timestamps)
durations_list.append(duration)
processed_videos.append(video)
else:
# Assume 24 fps by default and prepare timestamps for the whole video when all frames are sampled
processed_videos = videos
timestamps_list = [
[(int((idx / 24) // 60), int((idx / 24) % 60)) for idx in range(len(video))] for video in videos
]
durations_list = [len(video) // 24 for video in videos]
grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
resized_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
if do_convert_rgb:
@ -319,12 +378,15 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
if do_pad:
pad_size = get_max_height_width(processed_videos)
max_num_frames = max(len(video) for video in processed_videos)
grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
processed_padded_mask_grouped = {}
processed_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
stacked_videos, padded_masks = self.pad(stacked_videos, padded_size=pad_size)
stacked_videos, padded_masks = self.pad(
stacked_videos, padded_size=pad_size, max_num_frames=max_num_frames
)
processed_videos_grouped[shape] = stacked_videos
processed_padded_mask_grouped[shape] = padded_masks
@ -332,7 +394,7 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
pixel_attention_mask = reorder_videos(processed_padded_mask_grouped, grouped_videos_index)
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
data = {"pixel_values": processed_videos}
data = {"pixel_values": processed_videos, "timestamps": timestamps_list, "durations": durations_list}
if do_pad:
data["pixel_attention_mask"] = (

View File

@ -46,6 +46,7 @@ class VideoLlavaVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = VideoLlavaFastVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]

View File

@ -24,7 +24,7 @@ import typing
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict, Union
from typing import Any, Dict, Optional, TypedDict, Union
import numpy as np
import typing_extensions
@ -33,9 +33,9 @@ from huggingface_hub.errors import EntryNotFoundError
from .audio_utils import load_audio
from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature
from .image_utils import ChannelDimension, ImageInput, is_valid_image, is_vision_available, load_image
from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image
from .utils.chat_template_utils import render_jinja_template
from .video_utils import VideoInput, load_video
from .video_utils import VideoMetadata, load_video
if is_vision_available():
@ -64,6 +64,7 @@ from .utils import (
list_repo_templates,
logging,
)
from .utils.deprecation import deprecate_kwarg
logger = logging.get_logger(__name__)
@ -235,6 +236,14 @@ class VideosKwargs(TypedDict, total=False):
Whether to pad the video to the `(max_height, max_width)` of the videos in the batch.
do_center_crop (`bool`, *optional*):
Whether to center crop the video.
do_sample_frames (`bool`, *optional*):
Whether to sample frames from the video before processing or to process the whole video.
video_metadata (`VideoMetadata`, *optional*):
Metadata of the video containing information about total duration, fps and total number of frames.
num_frames (`int`, *optional*):
Maximum number of frames to sample when `do_sample_frames=True`.
fps (`int`, *optional*):
Target frames to sample per second when `do_sample_frames=True`.
crop_size (`Dict[str, int]`, *optional*):
Desired output size when applying center-cropping.
data_format (`ChannelDimension` or `str`, *optional*):
@ -260,6 +269,10 @@ class VideosKwargs(TypedDict, total=False):
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional[str]
do_sample_frames: Optional[bool]
video_metadata: Optional[Union[VideoMetadata, dict]]
fps: Optional[int]
num_frames: Optional[int]
class AudioKwargs(TypedDict, total=False):
@ -409,9 +422,6 @@ class ChatTemplateLoadKwargs(TypedDict, total=False):
The backend to use when loading the video which will be used only when there are videos in the conversation.
Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
that supports all types of sources to load from.
video_fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
@ -424,9 +434,7 @@ class ChatTemplateLoadKwargs(TypedDict, total=False):
return np.linspace(start_idx, end_idx, num_frames, dtype=int)
"""
num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None
sampling_rate: Optional[int] = 16_000
load_audio_from_video: Optional[bool] = False
@ -1371,40 +1379,7 @@ class ProcessorMixin(PushToHubMixin):
)
return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}
def _process_messages_for_chat_template(
self,
conversation: List[List[Dict[str, str]]],
batch_images: List[ImageInput],
batch_videos: List[VideoInput],
batch_video_metadata: List[List[Dict[str, any]]],
**mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
):
"""
Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
were sampled. This information cannot be accessed before the video is loaded.
For most models it is a no-op, and must be overridden by model processors which require special processing.
Args:
conversation (`List[Dict, str, str]`):
The conversation to process. Always comes in batched format.
batch_images (`List[List[ImageInput]]`):
Batch of images that were loaded from url/path defined in the conversation. The images
are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
per batch.
batch_videos (`List[List[ImageInput]]`):
Batch of videos that were loaded from url/path defined in the conversation. The videos
are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
per batch.
batch_video_metadata (`List[List[Dict[[str, any]]]]`):
Batch of metadata returned from loading videos. That includes video fps, duration and total number of frames in the original video.
Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
per batch.
"""
return conversation
@deprecate_kwarg("video_fps", version="4.58", new_name="fps")
def apply_chat_template(
self,
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
@ -1423,7 +1398,7 @@ class ProcessorMixin(PushToHubMixin):
{
"role": "user",
"content": [
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Please describe this image in detail."},
],
},
@ -1436,7 +1411,6 @@ class ProcessorMixin(PushToHubMixin):
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
chat template is used.
"""
if chat_template is None:
if isinstance(self.chat_template, dict) and "default" in self.chat_template:
chat_template = self.chat_template["default"]
@ -1545,16 +1519,12 @@ class ProcessorMixin(PushToHubMixin):
metadata = None
logger.warning(
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
"If your model requires metadata during processing, please load the whole video and let the processor sample frames instead."
)
else:
# TODO: raushan, should be `self.video_processor.load_video_for_model` when API is added
video, metadata = self._load_video_for_model(
video, metadata = load_video(
fname,
num_frames=mm_load_kwargs.get("num_frames", None),
fps=mm_load_kwargs.get("video_fps", None),
backend=mm_load_kwargs["video_load_backend"],
**kwargs,
)
videos.append(video)
video_metadata.append(metadata)
@ -1567,15 +1537,6 @@ class ProcessorMixin(PushToHubMixin):
batch_videos.append(videos)
batch_video_metadata.append(video_metadata)
# Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
conversations = self._process_messages_for_chat_template(
conversations,
batch_images=batch_images,
batch_videos=batch_videos,
batch_video_metadata=batch_video_metadata,
**processed_kwargs["mm_load_kwargs"],
)
prompt, generation_indices = render_jinja_template(
conversations=conversations,
chat_template=chat_template,
@ -1597,11 +1558,17 @@ class ProcessorMixin(PushToHubMixin):
if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
kwargs["add_special_tokens"] = False
# Always sample frames by default unless explicitly set to `False` by users. If users do not pass `num_frames`/`video_fps`,
# sampling is not done, to preserve BC.
if "do_sample_frames" not in kwargs and ("fps" in kwargs or "num_frames" in kwargs):
kwargs["do_sample_frames"] = True
out = self(
text=prompt,
images=batch_images if batch_images else None,
videos=batch_videos if batch_videos else None,
audio=batch_audios if batch_audios else None,
video_metadata=batch_video_metadata,
**kwargs,
)
if return_dict:
@ -1626,38 +1593,6 @@ class ProcessorMixin(PushToHubMixin):
return out["input_ids"]
return prompt
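The block above establishes the new default: passing `fps` or `num_frames` through `apply_chat_template` implies frame sampling unless the caller explicitly opts out. Reduced to a standalone sketch:

```python
def resolve_sampling(kwargs: dict) -> dict:
    if "do_sample_frames" not in kwargs and ("fps" in kwargs or "num_frames" in kwargs):
        kwargs["do_sample_frames"] = True
    return kwargs

print(resolve_sampling({"num_frames": 16}))                     # {'num_frames': 16, 'do_sample_frames': True}
print(resolve_sampling({"do_sample_frames": False, "fps": 2}))  # unchanged: the caller opted out
```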
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
# Keep private so we can simply remove when needed
def _load_video_for_model(
self,
video: Union[str, "VideoInput"],
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
**kwargs,
) -> np.array:
"""
Loads `video` to a numpy array.
Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""
video, metadata = load_video(video, num_frames, fps=fps, backend=backend)
return video, metadata
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
"""
Post-process the output of a vlm to decode the text.

View File

@ -51,6 +51,7 @@ from .utils import (
from .utils.import_utils import requires
from .video_utils import (
VideoInput,
VideoMetadata,
group_videos_by_shape,
load_video,
make_batched_videos,
@ -118,6 +119,14 @@ BASE_VIDEO_PROCESSOR_DOCSTRING = r"""
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the video to RGB.
video_metadata (`VideoMetadata`, *optional*):
Metadata of the video containing information about total duration, fps and total number of frames.
do_sample_frames (`bool`, *optional*, defaults to `self.do_sample_frames`):
Whether to sample frames from the video before processing or to process the whole video.
num_frames (`int`, *optional*, defaults to `self.num_frames`):
Maximum number of frames to sample when `do_sample_frames=True`.
fps (`int`, *optional*, defaults to `self.fps`):
Target frames to sample per second when `do_sample_frames=True`.
return_tensors (`str` or `TensorType`, *optional*):
Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
@ -157,6 +166,10 @@ class BaseVideoProcessor(BaseImageProcessorFast):
rescale_factor = 1 / 255
do_normalize = None
do_convert_rgb = None
do_sample_frames = None
fps = None
num_frames = None
video_metadata = None
valid_kwargs = VideosKwargs
model_input_names = ["pixel_values_videos"]
@ -219,9 +232,67 @@ class BaseVideoProcessor(BaseImageProcessorFast):
video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
return video
def sample_frames(
self,
video: "torch.Tensor",
metadata: Optional[Union[VideoMetadata, dict]] = None,
num_frames: Optional[int] = None,
fps: Optional[int] = None,
):
"""
Default sampling function which uniformly samples the desired number of frames between 0 and the total number of frames.
If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. The arguments `num_frames`
and `fps` are mutually exclusive.
Args:
video (`torch.Tensor`):
Video that needs to be sampled.
metadata (`VideoMetadata`, *optional*):
Metadata of the video containing information about total duration, fps and total number of frames.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
fps (`int`, *optional*):
Target frames to sample per second. Defaults to `self.fps`.
Returns:
torch.Tensor:
Sampled video frames.
"""
if fps is not None and num_frames is not None:
raise ValueError(
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
)
num_frames = num_frames if num_frames is not None else self.num_frames
fps = fps if fps is not None else self.fps
total_num_frames = video.shape[0]
# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is None and fps is not None:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
)
num_frames = int(total_num_frames / metadata["fps"] * fps)
if num_frames is not None and num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
)
if num_frames is not None:
indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
else:
indices = torch.arange(0, total_num_frames).int()
video = video[indices].contiguous()
return video
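# Illustrative sketch of `sample_frames` (assumes `video_processor` is an instantiated `BaseVideoProcessor`
# subclass; a plain dict with the same keys works in place of a `VideoMetadata` object):
#
#     video = torch.randint(0, 255, (16, 3, 224, 224), dtype=torch.uint8)  # 16 decoded frames
#     metadata = {"total_num_frames": 16, "fps": 4, "duration": 4.0}
#
#     # Fixed number of frames: indices = arange(0, 16, 16 / 4) -> frames 0, 4, 8, 12
#     sampled = video_processor.sample_frames(video, metadata, num_frames=4)
#
#     # Target fps: num_frames = int(16 / 4 * 2) = 8 uniformly spaced frames
#     sampled = video_processor.sample_frames(video, metadata, fps=2)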
def _prepare_input_videos(
self,
videos: VideoInput,
video_metadata: VideoMetadata = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
) -> List["torch.Tensor"]:
@ -229,6 +300,11 @@ class BaseVideoProcessor(BaseImageProcessorFast):
Prepare the input videos for processing.
"""
videos = make_batched_videos(videos)
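# `video_metadata` is expected to mirror the nested batch structure of `videos`
# (e.g. `[[{"duration": ..., "total_num_frames": ..., "fps": ...}]]` for a single video)
# and is flattened below to one metadata entry per video, or `None` when not provided.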
if video_metadata is not None:
batch_metadata = [metadata for batch_list in video_metadata for metadata in batch_list]
else:
batch_metadata = [None] * len(videos)
processed_videos = []
for video in videos:
# `make_batched_videos` always returns a 4D array per video
@ -242,7 +318,7 @@ class BaseVideoProcessor(BaseImageProcessorFast):
video = video.to(device)
processed_videos.append(video)
return processed_videos
return processed_videos, batch_metadata
@add_start_docstrings(BASE_VIDEO_PROCESSOR_DOCSTRING)
def preprocess(
@ -261,7 +337,10 @@ class BaseVideoProcessor(BaseImageProcessorFast):
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)
video_metadata = kwargs.pop("video_metadata")
videos, video_metadata = self._prepare_input_videos(
videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
)
kwargs = self._further_process_kwargs(**kwargs)
self._validate_preprocess_kwargs(**kwargs)
@ -276,11 +355,12 @@ class BaseVideoProcessor(BaseImageProcessorFast):
kwargs.pop("default_to_square")
kwargs.pop("data_format")
return self._preprocess(videos=videos, **kwargs)
return self._preprocess(videos=videos, video_metadata=video_metadata, **kwargs)
def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
@ -294,8 +374,18 @@ class BaseVideoProcessor(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_sample_frames: Optional[bool] = None,
fps: Optional[int] = None,
num_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if do_sample_frames:
# Sample video frames
videos = [
self.sample_frames(video, metadata=metadata, num_frames=num_frames, fps=fps)
for video, metadata in zip(videos, video_metadata)
]
# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}

View File

@ -74,6 +74,9 @@ class VideoMetadata:
duration: float
video_backend: str
def __getitem__(self, item):
return getattr(self, item)
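# With `__getitem__`, a `VideoMetadata` instance supports the same key-style access as the plain dicts
# accepted elsewhere, so e.g. `sample_frames` can read `metadata["fps"]` from either form.
# Sketch (field names assumed from the metadata docstrings elsewhere in this diff):
#
#     meta = VideoMetadata(total_num_frames=300, fps=30, duration=10.0, video_backend="opencv")
#     assert meta["fps"] == meta.fps == 30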
def is_valid_video_frame(frame):
return isinstance(frame, PIL.Image.Image) or (
@ -163,7 +166,7 @@ def make_batched_videos(videos) -> List[Union["np.ndarray", "torch.Tensor"]]:
videos = [np.array(videos)[None, ...]]
# nested batch so we need to unflatten
elif isinstance(videos[0], (list, tuple)) and is_valid_video(videos[0][0]):
return [video for sublist in videos for video in sublist]
videos = [video for sublist in videos for video in sublist]
return convert_pil_frames_to_video(videos)
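# With the assignment above (instead of the previous early return), nested batches now also pass through
# `convert_pil_frames_to_video`, so PIL-frame videos inside a nested list are converted the same way as
# un-nested inputs. Illustrative call shapes (hypothetical `video` inputs):
#
#     make_batched_videos(video)              # single video: list of frames or 4D array
#     make_batched_videos([[video1, video2]]) # nested batch: list of lists of videos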

View File

@ -17,7 +17,6 @@ import shutil
import tempfile
import unittest
from huggingface_hub import hf_hub_download
from parameterized import parameterized
from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
@ -180,77 +179,6 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
images_patches_index += inputs["pixel_values"].shape[0]
# Override video chat_template tests as InternVLProcessor returns flattened video features
@require_av
@require_torch
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def _process_messages_for_chat_template(
conversation,
batch_images,
batch_videos,
batch_video_metadata,
**chat_template_kwargs,
):
# Let us just always return a dummy prompt
new_msg = [
[
{
"role": "user",
"content": [
{"type": "video"}, # no need to use path, video is loaded already by this moment
{"type": "text", "text": "Dummy prompt for preprocess testing"},
],
},
]
]
return new_msg
processor._process_messages_for_chat_template = _process_messages_for_chat_template
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
num_frames=8,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
@require_torch
@require_av
def test_apply_chat_template_video_frame_sampling(self):
@ -393,13 +321,13 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors="pt",
num_frames=4, # by default no more than 4 frames, otherwise too slow
num_frames=2, # by default no more than 2 frames, otherwise too slow
)
self.assertTrue(self.videos_input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
video_len = 4 if batch_size == 1 else 3 # InternVL patches out and removes frames after processing
video_len = 2 if batch_size == 1 else 3 # InternVL patches out and removes frames after processing
self.assertEqual(len(out_dict[self.videos_input_name]), video_len)
for k in out_dict:
self.assertIsInstance(out_dict[k], torch.Tensor)

View File

@ -407,14 +407,14 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4, # by default no more than 4 frames, otherwise too slow
num_frames=2, # by default no more than 2 frames, otherwise too slow
)
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
video_len = 5760 if batch_size == 1 else 5808 # qwen pixels don't scale with bs same way as other models
video_len = 2880 if batch_size == 1 else 5808 # qwen pixels don't scale with bs same way as other models
mm_len = batch_size * 1564 if modality == "image" else video_len
self.assertEqual(len(out_dict[input_name]), mm_len)
@ -525,72 +525,6 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 2904)
@require_av
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def _process_messages_for_chat_template(
conversation,
batch_images,
batch_videos,
batch_video_metadata,
**chat_template_kwargs,
):
# Let us just always return a dummy prompt
new_msg = [
[
{
"role": "user",
"content": [
{"type": "video"}, # no need to use path, video is loaded already by this moment
{"type": "text", "text": "Dummy prompt for preprocess testing"},
],
},
]
]
return new_msg
processor._process_messages_for_chat_template = _process_messages_for_chat_template
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 145912)
@require_librosa
@require_av
@unittest.skip(

View File

@ -19,7 +19,6 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_av, require_torch, require_vision
@ -219,14 +218,14 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4, # by default no more than 4 frames, otherwise too slow
num_frames=2, # by default no more than 2 frames, otherwise too slow
)
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
video_len = 360 if batch_size == 1 else 320 # qwen pixels don't scale with bs same way as other models
video_len = 180 if batch_size == 1 else 320 # qwen pixels don't scale with bs same way as other models
mm_len = batch_size * 192 if modality == "image" else video_len
self.assertEqual(len(out_dict[input_name]), mm_len)
@ -346,70 +345,3 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(inputs[self.images_input_name].shape[0], 612)
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(inputs[self.images_input_name].shape[0], 100)
@require_av
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def _process_messages_for_chat_template(
conversation,
batch_images,
batch_videos,
batch_video_metadata,
**chat_template_kwargs,
):
# Let us just always return a dummy prompt
new_msg = [
[
{
"role": "user",
"content": [
{"type": "video"}, # no need to use path, video is loaded already by this moment
{"type": "text", "text": "Dummy prompt for preprocess testing"},
],
},
]
]
return new_msg
processor._process_messages_for_chat_template = _process_messages_for_chat_template
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 21960)

View File

@ -19,7 +19,6 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, Qwen2Tokenizer
from transformers.testing_utils import require_av, require_torch, require_vision
@ -220,14 +219,14 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4, # by default no more than 4 frames, otherwise too slow
num_frames=2, # by default no more than 2 frames, otherwise too slow
)
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), batch_size)
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
video_len = 360 if batch_size == 1 else 320 # qwen pixels don't scale with bs same way as other models
video_len = 180 if batch_size == 1 else 320 # qwen pixels don't scale with bs same way as other models
mm_len = batch_size * 192 if modality == "image" else video_len
self.assertEqual(len(out_dict[input_name]), mm_len)
@ -337,73 +336,6 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 160)
@require_av
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def _process_messages_for_chat_template(
conversation,
batch_images,
batch_videos,
batch_video_metadata,
**chat_template_kwargs,
):
# Let us just always return a dummy prompt
new_msg = [
[
{
"role": "user",
"content": [
{"type": "video"}, # no need to use path, video is loaded already by this moment
{"type": "text", "text": "Dummy prompt for preprocess testing"},
],
},
]
]
return new_msg
processor._process_messages_for_chat_template = _process_messages_for_chat_template
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 21960)
def test_kwargs_overrides_custom_image_processor_kwargs(self):
processor = self.get_processor()
self.skip_processor_without_typed_kwargs(processor)

View File

@ -99,8 +99,9 @@ class Qwen2VLVideoProcessingTester:
}
@require_vision
def expected_output_video_shape(self, videos):
grid_t = self.num_frames // self.temporal_patch_size
def expected_output_video_shape(self, videos, num_frames=None):
num_frames = num_frames if num_frames is not None else self.num_frames
grid_t = num_frames // self.temporal_patch_size
hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
seq_len = 0
for video in videos:
@ -289,3 +290,70 @@ class Qwen2VLVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
)[self.input_name]
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
def test_call_sample_frames(self):
for video_processing_class in self.video_processor_list:
video_processing = video_processing_class(**self.video_processor_dict)
prev_num_frames = self.video_processor_tester.num_frames
self.video_processor_tester.num_frames = 8
video_inputs = self.video_processor_tester.prepare_video_inputs(
equal_resolution=False,
return_tensors="torch",
)
# Force-set sampling to `False`. No sampling is expected even when `num_frames` is passed
video_processing.do_sample_frames = False
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
expected_output_video_shape_batched = self.video_processor_tester.expected_output_video_shape(video_inputs)
self.assertListEqual(list(encoded_videos.shape), expected_output_video_shape)
self.assertListEqual(list(encoded_videos_batched.shape), expected_output_video_shape_batched)
# Set sampling to True. Video frames should be sampled with `num_frames` in the output
video_processing.do_sample_frames = True
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=4)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=4)[self.input_name]
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(
[video_inputs[0]], num_frames=4
)
expected_output_video_shape_batched = self.video_processor_tester.expected_output_video_shape(
video_inputs, num_frames=4
)
self.assertListEqual(list(encoded_videos.shape), expected_output_video_shape)
self.assertListEqual(list(encoded_videos_batched.shape), expected_output_video_shape_batched)
# Sampling with `fps` requires metadata to infer the number of frames from the total duration
with self.assertRaises(ValueError):
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", fps=3)[self.input_name]
metadata = [[{"duration": 2.0, "total_num_frames": 8, "fps": 4}]]
batched_metadata = metadata * len(video_inputs)
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3, video_metadata=metadata)[
self.input_name
]
encoded_videos_batched = video_processing(
video_inputs, return_tensors="pt", fps=3, video_metadata=batched_metadata
)[self.input_name]
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(
[video_inputs[0]], num_frames=6
)
expected_output_video_shape_batched = self.video_processor_tester.expected_output_video_shape(
video_inputs, num_frames=6
)
self.assertListEqual(list(encoded_videos.shape), expected_output_video_shape)
self.assertListEqual(list(encoded_videos_batched.shape), expected_output_video_shape_batched)
# We should raise an error when asked to sample more frames than there are in the input video
with self.assertRaises(ValueError):
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=10)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=10)[
self.input_name
]
# Assign back the actual num frames in tester
self.video_processor_tester.num_frames = prev_num_frames

View File

@ -16,6 +16,7 @@ import shutil
import tempfile
import unittest
from io import BytesIO
from typing import Optional
import numpy as np
import requests
@ -63,7 +64,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
cls.bos_token = processor.tokenizer.bos_token
cls.image_token = processor.image_token
cls.video_token = processor.image_token * 8 # SmolVLM uses image token and repeats it `num_frames` times
cls.video_token = processor.video_token
cls.fake_image_token = processor.fake_image_token
cls.global_img_token = processor.global_image_token
@ -93,6 +94,13 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
"chat_template": "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
}
def prepare_video_inputs(self, batch_size: Optional[int] = None):
"""This function prepares a list of numpy videos."""
video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
if batch_size is None:
return [[video_input]]
return [[video_input]] * batch_size
def get_split_image_expected_tokens(self, processor, image_rows, image_cols):
text_split_images = []
for n_h in range(image_rows):
@ -347,7 +355,6 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
{"type": "text", "text": "What do these images show?"},
{"type": "image"},
{"type": "image"},
"What do these images show?",
],
},
{
@ -373,11 +380,8 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
self.assertEqual(rendered, expected_rendered)
@unittest.skip(reason="SmolVLM replaced `type=video` with `type=image` in chat templates")
def test_apply_chat_template_video_special_processing(self):
pass
@require_av
@require_torch
def test_apply_chat_template_video_frame_sampling(self):
# overridden because SmolVLM has special preprocessing for videos
processor = self.get_processor()
@ -406,7 +410,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
num_frames=num_frames,
return_tensors="np",
return_tensors="pt",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@ -421,7 +425,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenize=True,
return_dict=True,
video_fps=video_fps,
return_tensors="np",
return_tensors="pt",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@ -482,11 +486,11 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
do_rescale=True,
rescale_factor=-1,
padding="max_length",
max_length=76,
max_length=172,
)
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(len(inputs["input_ids"][0]), 172)
@require_torch
@require_vision

View File

@ -15,22 +15,16 @@
import unittest
import numpy as np
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
if is_torch_available():
import torch
if is_vision_available():
if is_torchvision_available():
from transformers import SmolVLMVideoProcessor
from transformers.models.smolvlm.video_processing_smolvlm import get_resize_output_image_size
class SmolVLMVideoProcessingTester:
@ -58,6 +52,7 @@ class SmolVLMVideoProcessingTester:
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.max_image_size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
@ -71,17 +66,16 @@ class SmolVLMVideoProcessingTester:
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
"max_image_size": self.max_image_size,
}
def expected_output_video_shape(self, videos):
max_height, max_width = 0, 0
if not isinstance(videos[0], torch.Tensor):
videos = [torch.tensor(np.array(video)).permute(0, -1, -3, -2) for video in videos]
for video in videos:
height, width = get_resize_output_image_size(video, self.size["longest_edge"])
max_height = max(height, max_height)
max_width = max(width, max_width)
return [self.num_frames, self.num_channels, max_height, max_width]
return [
self.num_frames,
self.num_channels,
self.max_image_size["longest_edge"],
self.max_image_size["longest_edge"],
]
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
videos = prepare_video_inputs(
@ -116,3 +110,58 @@ class SmolVLMVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42)
self.assertEqual(video_processor.size, {"height": 42, "width": 42})
# overwrite, SmolVLM requires metadata no matter how we sample
def test_call_sample_frames(self):
for video_processing_class in self.video_processor_list:
video_processing = video_processing_class(**self.video_processor_dict)
prev_num_frames = self.video_processor_tester.num_frames
self.video_processor_tester.num_frames = 8
video_inputs = self.video_processor_tester.prepare_video_inputs(
equal_resolution=False,
return_tensors="torch",
)
# Force-set sampling to `False`. No sampling is expected even when `num_frames` is passed
video_processing.do_sample_frames = False
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
self.assertEqual(encoded_videos.shape[1], 8)
self.assertEqual(encoded_videos_batched.shape[1], 8)
# Set sampling to True. Video frames should be sampled with `num_frames` in the output
video_processing.do_sample_frames = True
metadata = [[{"duration": 2.0, "total_num_frames": 8, "fps": 4}]]
batched_metadata = metadata * len(video_inputs)
# Sampling with `fps` requires metadata to infer the number of frames from the total duration
with self.assertRaises(ValueError):
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=6, fps=3)[
self.input_name
]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=6, fps=3)[
self.input_name
]
encoded_videos = video_processing(
video_inputs[0], return_tensors="pt", num_frames=6, fps=3, video_metadata=metadata
)[self.input_name]
encoded_videos_batched = video_processing(
video_inputs, return_tensors="pt", num_frames=6, fps=3, video_metadata=batched_metadata
)[self.input_name]
self.assertEqual(encoded_videos.shape[1], 6)
self.assertEqual(encoded_videos_batched.shape[1], 6)
# We should raise an error when asked to sample more frames than there are in the input video
with self.assertRaises(ValueError):
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=10, num_frames=20)[
self.input_name
]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", fps=10, num_frames=20)[
self.input_name
]
# Assign back the actual num frames in tester
self.video_processor_tester.num_frames = prev_num_frames

View File

@ -507,7 +507,7 @@ class ProcessorTesterMixin:
if "video_processor" not in self.processor_class.attributes:
self.skipTest(f"video_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=167, padding="max_length")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(**processor_components, **processor_kwargs)
@ -515,7 +515,7 @@ class ProcessorTesterMixin:
input_str = self.prepare_text_inputs(modality="video")
video_input = self.prepare_video_inputs()
inputs = processor(text=input_str, videos=video_input, return_tensors="pt")
self.assertEqual(inputs[self.text_input_name].shape[-1], 117)
self.assertEqual(inputs[self.text_input_name].shape[-1], 167)
def test_video_processor_defaults_preserved_by_video_kwargs(self):
"""
@ -529,7 +529,7 @@ class ProcessorTesterMixin:
processor_components["video_processor"] = self.get_component(
"video_processor", do_rescale=True, rescale_factor=-1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=167, padding="max_length")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(**processor_components, **processor_kwargs)
@ -553,9 +553,9 @@ class ProcessorTesterMixin:
input_str = self.prepare_text_inputs(modality="video")
video_input = self.prepare_video_inputs()
inputs = processor(
text=input_str, videos=video_input, return_tensors="pt", max_length=112, padding="max_length"
text=input_str, videos=video_input, return_tensors="pt", max_length=162, padding="max_length"
)
self.assertEqual(inputs[self.text_input_name].shape[-1], 112)
self.assertEqual(inputs[self.text_input_name].shape[-1], 162)
def test_kwargs_overrides_default_video_processor_kwargs(self):
if "video_processor" not in self.processor_class.attributes:
@ -564,7 +564,7 @@ class ProcessorTesterMixin:
processor_components["video_processor"] = self.get_component(
"video_processor", do_rescale=True, rescale_factor=1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=167, padding="max_length")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(**processor_components, **processor_kwargs)
@ -593,11 +593,11 @@ class ProcessorTesterMixin:
do_rescale=True,
rescale_factor=-1,
padding="max_length",
max_length=76,
max_length=176,
)
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
self.assertEqual(inputs[self.text_input_name].shape[-1], 176)
def test_unstructured_kwargs_batched_video(self):
if "video_processor" not in self.processor_class.attributes:
@ -616,13 +616,13 @@ class ProcessorTesterMixin:
do_rescale=True,
rescale_factor=-1,
padding="longest",
max_length=76,
max_length=176,
)
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
self.assertTrue(
len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1])
and len(inputs[self.text_input_name][1]) < 76
and len(inputs[self.text_input_name][1]) < 176
)
def test_doubly_passed_kwargs_video(self):
@ -659,14 +659,14 @@ class ProcessorTesterMixin:
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"videos_kwargs": {"do_rescale": True, "rescale_factor": -1},
"text_kwargs": {"padding": "max_length", "max_length": 76},
"text_kwargs": {"padding": "max_length", "max_length": 176},
}
inputs = processor(text=input_str, videos=video_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
self.assertEqual(inputs[self.text_input_name].shape[-1], 176)
def test_structured_kwargs_nested_from_dict_video(self):
if "video_processor" not in self.processor_class.attributes:
@ -682,12 +682,12 @@ class ProcessorTesterMixin:
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"videos_kwargs": {"do_rescale": True, "rescale_factor": -1},
"text_kwargs": {"padding": "max_length", "max_length": 76},
"text_kwargs": {"padding": "max_length", "max_length": 176},
}
inputs = processor(text=input_str, videos=video_input, **all_kwargs)
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
self.assertEqual(inputs[self.text_input_name].shape[-1], 176)
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs
# TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
@ -884,7 +884,7 @@ class ProcessorTesterMixin:
tokenize=True,
return_dict=True,
return_tensors=return_tensors,
num_frames=4, # by default no more than 4 frames, otherwise too slow
num_frames=2, # by default no more than 2 frames, otherwise too slow
)
input_name = getattr(self, input_name)
self.assertTrue(input_name in out_dict)
@ -983,6 +983,21 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), video_fps * 10)
# When `do_sample_frames=False`, no sampling is done and the whole video is loaded, even if the number of frames is passed
video_fps = 1
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
do_sample_frames=False,
video_fps=video_fps,
return_tensors="pt",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 300)
# Load with `video_fps` and `num_frames` args, should raise an error
with self.assertRaises(ValueError):
out_dict_with_video = processor.apply_chat_template(
@ -1024,75 +1039,6 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
@require_av
@require_torch
def test_apply_chat_template_video_special_processing(self):
"""
Tests that models can use their own preprocessing to preprocess conversations.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
messages = [
[
{
"role": "user",
"content": [
{"type": "video", "path": video_file_path},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
def _process_messages_for_chat_template(
conversation,
batch_images,
batch_videos,
batch_video_metadata,
**chat_template_kwargs,
):
# Let us just always return a dummy prompt
new_msg = [
[
{
"role": "user",
"content": [
{"type": "video"}, # no need to use path, video is loaded already by this moment
{"type": "text", "text": "Dummy prompt for preprocess testing"},
],
},
]
]
return new_msg
processor._process_messages_for_chat_template = _process_messages_for_chat_template
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
# Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
@require_librosa
@require_av
def test_chat_template_audio_from_video(self):

View File

@ -293,6 +293,59 @@ class VideoProcessingTestMixin:
(self.video_processor_tester.batch_size, *expected_output_video_shape),
)
def test_call_sample_frames(self):
for video_processing_class in self.video_processor_list:
video_processing = video_processing_class(**self.video_processor_dict)
prev_num_frames = self.video_processor_tester.num_frames
self.video_processor_tester.num_frames = 8
video_inputs = self.video_processor_tester.prepare_video_inputs(
equal_resolution=False,
return_tensors="torch",
)
# Force-set sampling to `False`. No sampling is expected even when `num_frames` is passed
video_processing.do_sample_frames = False
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
self.assertEqual(encoded_videos.shape[1], 8)
self.assertEqual(encoded_videos_batched.shape[1], 8)
# Set sampling to True. Video frames should be sampled with `num_frames` in the output
video_processing.do_sample_frames = True
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
self.assertEqual(encoded_videos.shape[1], 3)
self.assertEqual(encoded_videos_batched.shape[1], 3)
# Sampling with `fps` requires metadata to infer the number of frames from the total duration
with self.assertRaises(ValueError):
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", fps=3)[self.input_name]
metadata = [[{"duration": 2.0, "total_num_frames": 8, "fps": 4}]]
batched_metadata = metadata * len(video_inputs)
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3, video_metadata=metadata)[
self.input_name
]
encoded_videos_batched = video_processing(
video_inputs, return_tensors="pt", fps=3, video_metadata=batched_metadata
)[self.input_name]
self.assertEqual(encoded_videos.shape[1], 6)
self.assertEqual(encoded_videos_batched.shape[1], 6)
# We should raise an error when asked to sample more frames than there are in the input video
with self.assertRaises(ValueError):
encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=10)[self.input_name]
encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=10)[
self.input_name
]
# Assign back the actual num frames in tester
self.video_processor_tester.num_frames = prev_num_frames
def test_nested_input(self):
"""Tests that the processor can work with nested list where each video is a list of arrays"""
for video_processing_class in self.video_processor_list: