diff --git a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
index 5616e8f3e99..f120006d40d 100644
--- a/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/video_processing_instructblipvideo.py
@@ -35,7 +35,7 @@ from ...utils import (
 )
 from ...utils.import_utils import requires
 from ...video_processing_utils import BaseVideoProcessor
-from ...video_utils import group_videos_by_shape, reorder_videos
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
 
 
 if is_vision_available():
@@ -66,6 +66,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = InstructBlipVideoVideoProcessorInitKwargs
     model_input_names = ["pixel_values"]
 
@@ -75,6 +76,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
     def _preprocess(
         self,
         videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
         do_convert_rgb: bool,
         do_resize: bool,
         size: SizeDict,
@@ -86,10 +88,18 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
         do_pad: bool,
         rescale_factor: float,
         do_normalize: bool,
+        do_sample_frames: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
-        return_tensors: Optional[Union[str, TensorType]],
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
     ) -> BatchFeature:
+        if do_sample_frames:
+            videos = [
+                self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
+            ]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/internvl/processing_internvl.py b/src/transformers/models/internvl/processing_internvl.py
index c9a8c2028d5..bd47a231e5c 100644
--- a/src/transformers/models/internvl/processing_internvl.py
+++ b/src/transformers/models/internvl/processing_internvl.py
@@ -21,7 +21,7 @@ from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
 from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
+from ...video_utils import VideoInput, make_batched_videos
 
 
 class InternVLImagesKwargs(ImagesKwargs, total=False):
@@ -290,32 +290,6 @@ class InternVLProcessor(ProcessorMixin):
 
         return MultiModalData(**vision_data)
 
-    def sample_indices_fn(
-        self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
-    ):
-        """
-        The function to generate indices of frames to sample from a video.
-
-        Args:
-            metadata (`VideoMetadata`):
-                `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
-            num_frames (`int`, *optional*):
-                Number of frames to sample uniformly. If None, all frames are sampled.
-            initial_shift (`bool`, `float` or `int`, defaults to `0`):
-                The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
-
-        Returns:
-            `np.ndarray`: Array of frame indices to sample.
-        """
-        num_frames = num_frames if num_frames is not None else metadata.total_num_frames
-
-        if initial_shift is True:
-            initial_shift = metadata.total_num_frames / num_frames / 2
-        indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
-            int
-        )
-        return indices
-
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
@@ -336,39 +310,5 @@ class InternVLProcessor(ProcessorMixin):
         image_processor_input_names = self.image_processor.model_input_names
         return list(tokenizer_input_names) + list(image_processor_input_names)
 
-    # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
-    def _load_video_for_model(
-        self,
-        video: Union[str, "VideoInput"],
-        num_frames: Optional[int],
-        backend: str = "pyav",
-        initial_shift: bool = True,
-        **kwargs,
-    ) -> np.array:
-        """
-        Loads `video` to a numpy array.
-
-        Args:
-            video (`str` or `VideoInput`):
-                The video to convert to the numpy array format. Can be a link to video or local path.
-            num_frames (`int`, *optional*):
-                Number of frames to sample uniformly. If not passed, the whole video is loaded.
-            backend (`str`, *optional*, defaults to `"pyav"`):
-                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
-            initial_shift (`bool`, *optional*, defaults to `True`):
-                The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
-
-        Returns:
-            Tuple[`np.array`, Dict]: A tuple containing:
-                - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
-                - Metadata dictionary.
-        """
-
-        def sample_indices_fn_func(metadata, **fn_kwargs):
-            return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
-
-        video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
-        return video, metadata
-
 
 __all__ = ["InternVLProcessor"]
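The uniform-sampling logic removed from `InternVLProcessor` above now lives on the video-processor side (see `InternVLVideoProcessor.sample_frames` in the next file). A minimal NumPy sketch of that index computation, kept here for reference; the function name is illustrative, not part of the transformers API:

import numpy as np

def uniform_indices_with_initial_shift(total_num_frames: int, num_frames: int, initial_shift=True):
    # With initial_shift=True, start half a sampling step into the video so that frames
    # are taken from the middle of each uniform window (same arithmetic as the removed code).
    if initial_shift is True:
        initial_shift = total_num_frames / num_frames / 2
    return np.arange(initial_shift, total_num_frames, total_num_frames / num_frames).astype(int)

# e.g. uniform_indices_with_initial_shift(32, 8) -> array([ 2,  6, 10, 14, 18, 22, 26, 30])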
"""Fast Video processor class for InternVL.""" +from typing import List, Optional, Union + +from ...image_processing_utils import BatchFeature from ...image_utils import ( OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, + SizeDict, ) from ...processing_utils import Unpack, VideosKwargs from ...utils import ( + TensorType, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, is_vision_available, ) from ...utils.import_utils import requires -from ...video_processing_utils import ( - BaseVideoProcessor, -) +from ...video_processing_utils import BaseVideoProcessor +from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +if is_torch_available(): + import torch + if is_vision_available(): from ...image_utils import PILImageResampling -class InternVLVideoProcessorInitKwargs(VideosKwargs): ... +class InternVLVideoProcessorInitKwargs(VideosKwargs): + initial_shift: Union[bool, float, int] @requires(backends=("torchvision",)) @@ -45,11 +63,128 @@ class InternVLVideoProcessor(BaseVideoProcessor): do_rescale = True do_normalize = True do_convert_rgb = True + initial_shift = True + do_sample_frames = False # Set to False for BC, recommended to set `True` in new models valid_kwargs = InternVLVideoProcessorInitKwargs model_input_names = ["pixel_values_videos"] def __init__(self, **kwargs: Unpack[InternVLVideoProcessorInitKwargs]): super().__init__(**kwargs) + def sample_frames( + self, + video: "torch.Tensor", + metadata: Optional[Union[VideoMetadata, dict]] = None, + num_frames: Optional[int] = None, + fps: Optional[int] = None, + initial_shift: Optional[Union[bool, float, int]] = None, + ): + """ + Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames. + If `fps` is passed along with metadata, `fps` frames per second are sampled uniformty. Arguments `num_frames` + and `fps` are mutually exclusive. + + Args: + video (`torch.Tensor`): + Video that need to be sampled. + metadata (`VideoMetadata`, *optional*): + Metadata of the video containing information about total duration, fps and total number of frames. + num_frames (`int`, *optional*): + Maximum number of frames to sample. Defaults to `self.num_frames`. + fps (`int`, *optional*): + Target frames to sample per second. Defaults to `self.fps`. + initial_shift (`bool`, `float` or `int`, defaults to `self.initial_shift`): + The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video. + + Returns: + torch.Tensor: + Sampled video frames. + """ + num_frames = num_frames if num_frames is not None else self.num_frames + initial_shift = initial_shift if initial_shift is not None else self.initial_shift + total_num_frames = video.shape[0] + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames is None and fps is not None: + if metadata is None: + raise ValueError( + "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. 
" + "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video" + ) + num_frames = int(total_num_frames / metadata["fps"] * fps) + + if initial_shift is True: + initial_shift = total_num_frames / num_frames / 2 + + if num_frames > total_num_frames: + raise ValueError( + f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. " + ) + + indices = torch.arange(initial_shift, total_num_frames, total_num_frames / num_frames).int() + video = video[indices].contiguous() + return video + + def _preprocess( + self, + videos: List["torch.Tensor"], + video_metadata: Union[List[VideoMetadata], List[dict]], + do_convert_rgb: bool, + do_resize: bool, + size: SizeDict, + size_divisor: Optional[int], + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + do_pad: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + do_sample_frames: Optional[bool] = None, + fps: Optional[int] = None, + num_frames: Optional[int] = None, + initial_shift: Optional[Union[bool, float, int]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if do_sample_frames: + # Sample video frames + videos = [ + self.sample_frames(video, metadata, fps=fps, num_frames=num_frames, initial_shift=initial_shift) + for video, metadata in zip(videos, video_metadata) + ] + + # Group videos by size for batched resizing + grouped_videos, grouped_videos_index = group_videos_by_shape(videos) + resized_videos_grouped = {} + for shape, stacked_videos in grouped_videos.items(): + if do_convert_rgb: + stacked_videos = self.convert_to_rgb(stacked_videos) + if do_resize: + stacked_videos = self.resize( + stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation + ) + resized_videos_grouped[shape] = stacked_videos + resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) + + # Group videos by size for further processing + # Needed in case do_resize is False, or resize returns videos with different sizes + grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos) + processed_videos_grouped = {} + for shape, stacked_videos in grouped_videos.items(): + if do_center_crop: + stacked_videos = self.center_crop(stacked_videos, crop_size) + # Fused rescale and normalize + stacked_videos = self.rescale_and_normalize( + stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_videos_grouped[shape] = stacked_videos + + processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) + processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos + + return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors) + __all__ = ["InternVLVideoProcessor"] diff --git a/src/transformers/models/llava_next_video/video_processing_llava_next_video.py b/src/transformers/models/llava_next_video/video_processing_llava_next_video.py index 390028a0070..95cd79da655 100644 --- a/src/transformers/models/llava_next_video/video_processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/video_processing_llava_next_video.py @@ -46,6 +46,7 @@ class LlavaNextVideoVideoProcessor(BaseVideoProcessor): do_rescale = True do_normalize = True do_convert_rgb = True + do_sample_frames = False # Set to False for BC, recommended to 
set `True` in new models valid_kwargs = LlavaNextVideoFastVideoProcessorInitKwargs model_input_names = ["pixel_values_videos"] diff --git a/src/transformers/models/llava_onevision/video_processing_llava_onevision.py b/src/transformers/models/llava_onevision/video_processing_llava_onevision.py index cee54a357e0..3972f424a94 100644 --- a/src/transformers/models/llava_onevision/video_processing_llava_onevision.py +++ b/src/transformers/models/llava_onevision/video_processing_llava_onevision.py @@ -47,6 +47,7 @@ class LlavaOnevisionVideoProcessor(BaseVideoProcessor): do_rescale = True do_normalize = True do_convert_rgb = True + do_sample_frames = False # Set to False for BC, recommended to set `True` in new models valid_kwargs = LlavaOnevisionFastVideoProcessorInitKwargs model_input_names = ["pixel_values_videos"] diff --git a/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py index ea449184705..e211b8d911f 100644 --- a/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py @@ -154,7 +154,7 @@ class Qwen2_5OmniProcessor(ProcessorMixin): seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk") position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds") use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video") - fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + fps = output_kwargs["videos_kwargs"].get("fps", 2.0) if audio is not None: output_kwargs["audio_kwargs"]["padding"] = "max_length" # Support "max_length" padding only here diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index f293f5c769c..e4685acc644 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -928,7 +928,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): "padding": False, "return_mm_token_type_ids": False, }, - "videos_kwargs": {"fps": 2.0}, } @@ -1013,9 +1012,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor): image_grid_thw = image_inputs["image_grid_thw"] if videos is not None: - # pop fps in advance for passing kwargs validation - fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) - + fps = output_kwargs["videos_kwargs"].get("fps", 2.0) videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index f835390a079..e145791eea9 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -54,7 +54,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): "padding": False, "return_mm_token_type_ids": False, }, - "videos_kwargs": {"fps": 2.0}, } @@ -151,9 +150,7 @@ class Qwen2_5_VLProcessor(ProcessorMixin): image_grid_thw = image_inputs["image_grid_thw"] if videos is not None: - # pop fps in advance for passing kwargs validation - fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) - + fps = output_kwargs["videos_kwargs"].get("fps", 2.0) videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] diff --git 
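A small sketch of why the `.pop("fps", 2.0)` -> `.get("fps", 2.0)` change in the two Qwen2.5-VL processors above matters: with `pop`, the key was consumed by the processor and never reached the video processor; with `get`, the processor still reads the value while the key stays in `videos_kwargs` and is forwarded for frame sampling. The dict below is illustrative only:

videos_kwargs = {"fps": 2.0, "return_tensors": "pt"}

fps = videos_kwargs.pop("fps", 2.0)  # old behaviour: "fps" is consumed here
assert "fps" not in videos_kwargs    # ...so the video processor never saw it

videos_kwargs = {"fps": 2.0, "return_tensors": "pt"}
fps = videos_kwargs.get("fps", 2.0)  # new behaviour: read without removing
assert videos_kwargs["fps"] == 2.0   # ...and it is passed on with the other video kwargs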
diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
index 991459887b2..49a4e9d2efc 100644
--- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
@@ -19,6 +19,7 @@
 # limitations under the License.
 """video processor class for Qwen2-VL."""
 
+import math
 from typing import List, Optional, Union
 
 from ...image_processing_utils import (
@@ -45,7 +46,7 @@ from ...video_processing_utils import (
     BASE_VIDEO_PROCESSOR_DOCSTRING,
     BaseVideoProcessor,
 )
-from ...video_utils import group_videos_by_shape, reorder_videos
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
 
 
 if is_vision_available():
@@ -69,6 +70,8 @@ class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
     patch_size: Optional[int]
     temporal_patch_size: Optional[int]
     merge_size: Optional[int]
+    min_frames: Optional[int]
+    max_frames: Optional[int]
 
 
 @add_start_docstrings(
@@ -85,23 +88,30 @@ class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
         The temporal patch size of the vision encoder.
     merge_size (`int`, *optional*, defaults to 2):
         The merge size of the vision encoder to llm encoder.
+    min_frames (`int`, *optional*, defaults to 4):
+        The minimum number of frames that can be sampled.
+    max_frames (`int`, *optional*, defaults to 768):
+        The maximum number of frames that can be sampled.
     """,
 )
 @requires(backends=("torchvision",))
 class Qwen2VLVideoProcessor(BaseVideoProcessor):
     resample = PILImageResampling.BICUBIC
-    size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
+    size = {"shortest_edge": 128 * 28 * 28, "longest_edge": 28 * 28 * 768}
     image_mean = OPENAI_CLIP_MEAN
     image_std = OPENAI_CLIP_STD
     do_resize = True
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
-    min_pixels = 56 * 56
-    max_pixels = 28 * 28 * 1280
+    min_pixels = 128 * 28 * 28
+    max_pixels = 28 * 28 * 768
     patch_size = 14
     temporal_patch_size = 2
     merge_size = 2
+    min_frames = 4
+    max_frames = 768
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = Qwen2VLVideoProcessorInitKwargs
     model_input_names = ["pixel_values_videos", "video_grid_thw"]
 
@@ -109,9 +119,80 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
         super().__init__(**kwargs)
         self.size = {"shortest_edge": self.min_pixels, "longest_edge": self.max_pixels}
 
+    def sample_frames(
+        self,
+        video: "torch.Tensor",
+        frame_factor: int,
+        min_frames: int,
+        max_frames: int,
+        metadata: Optional[Union[VideoMetadata, dict]] = None,
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+    ):
+        """
+        Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+        and `fps` are mutually exclusive.
+
+        Args:
+            video (`torch.Tensor`):
+                Video that needs to be sampled.
+            frame_factor (`int`):
+                The temporal patch size of the vision encoder. The number of sampled frames is rounded to be divisible by the frame factor.
+            min_frames (`int`):
+                The minimum number of frames that can be sampled.
+            max_frames (`int`):
+                The maximum number of frames that can be sampled.
+            metadata (`VideoMetadata`, *optional*):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+
+        Returns:
+            torch.Tensor:
+                Sampled video frames.
+        """
+        if fps is not None and num_frames is not None:
+            raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
+
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        fps = fps if fps is not None else self.fps
+        total_num_frames = video.shape[0]
+
+        # If num_frames is not given but fps is, calculate num_frames from fps
+        if num_frames is not None:
+            num_frames = round(num_frames / frame_factor) * frame_factor
+        elif fps is not None:
+            if metadata is None:
+                raise ValueError(
+                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+                    "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
+                )
+            max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor
+            num_frames = total_num_frames / metadata["fps"] * fps
+            num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
+            num_frames = math.floor(num_frames / frame_factor) * frame_factor
+
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+                "Decrease `num_frames` or `fps` for sampling."
+            )
+
+        if num_frames is not None:
+            indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
+        else:
+            indices = torch.arange(0, total_num_frames).int()
+        video = video[indices].contiguous()
+
+        return video
+
     def _preprocess(
         self,
         videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
         do_convert_rgb: bool,
         do_resize: bool,
         size: SizeDict,
@@ -119,6 +200,7 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
+        do_sample_frames: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
         min_pixels: Optional[int] = None,
@@ -126,9 +208,28 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
         patch_size: Optional[int] = None,
         temporal_patch_size: Optional[int] = None,
         merge_size: Optional[int] = None,
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        min_frames: Optional[int] = None,
+        max_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ):
+        if do_sample_frames:
+            # Sample video frames
+            videos = [
+                self.sample_frames(
+                    video,
+                    frame_factor=temporal_patch_size,
+                    min_frames=min_frames,
+                    max_frames=max_frames,
+                    metadata=metadata,
+                    num_frames=num_frames,
+                    fps=fps,
+                )
+                for video, metadata in zip(videos, video_metadata)
+            ]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
diff --git a/src/transformers/models/smolvlm/processing_smolvlm.py b/src/transformers/models/smolvlm/processing_smolvlm.py
index a440a0f29b1..8613a2f88c6 100644
--- a/src/transformers/models/smolvlm/processing_smolvlm.py
+++ b/src/transformers/models/smolvlm/processing_smolvlm.py
@@ -16,18 +16,15 @@ Processor class for SmolVLM.
 """
 
-import copy
 from datetime import timedelta
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
-import numpy as np
-
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
-from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...processing_utils import AllKwargsForChatTemplate, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import BatchEncoding, TextInput
 from ...utils import is_num2words_available, is_vision_available, logging
-from ...video_utils import VideoInput, load_video, make_batched_videos
+from ...video_utils import VideoInput
 
 
 if is_vision_available():
@@ -35,7 +32,13 @@ if is_vision_available():
         DEFAULT_MEDIA_OUTTRO,
         DEFAULT_VIDEO_INTRO,
         FRAME_TIMESTAMP_MESSAGE,
-        smolvlm_sample_indices_fn,
+    )
+
+if is_vision_available():
+    from .video_processing_smolvlm import (
+        DEFAULT_MEDIA_OUTTRO,
+        DEFAULT_VIDEO_INTRO,
+        FRAME_TIMESTAMP_MESSAGE,
     )
 
 if TYPE_CHECKING:
@@ -50,6 +53,10 @@ else:
     num2words = None
 
 
+# The correct chat template to be used for videos after #38105
+DEFAULT_CHAT_TEMPLATE = "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '' }}{% elif line['type'] == 'video' %}{{ '