Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
[video processors] support frame sampling within processors (#38105)
* apply updates to SmolVLM (still needs a workaround for the chat template)
* add other models
* dump Qwen Omni for now, come back later
* port Qwen Omni from their impl
* wait, all Qwens sample videos in the same way!
* clean up
* make SmolVLM backwards compatible and fix padding
* fix some tests
* fix SmolVLM tests
* more clean up and test fixing
* delete unused arg
* fix
* address comments
* style
* fix test
This commit is contained in:
parent 887054c714
commit 27459025b8
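The core idea of the commit, as a minimal standalone sketch (plain PyTorch; the function names here are illustrative, not the library API): video processors can now sample frames themselves when `do_sample_frames=True`, instead of requiring callers to pre-sample the exact frames.

# Illustrative sketch only: mirrors the gating pattern added to `_preprocess` in this commit.
import torch

def sample_frames(video, num_frames=None, fps=None, metadata=None):
    total = video.shape[0]
    if num_frames is None and fps is not None:
        num_frames = int(total / metadata["fps"] * fps)  # derive frame count from a target fps
    indices = torch.arange(0, total, total / num_frames).int()
    return video[indices].contiguous()

def preprocess(videos, do_sample_frames=False, num_frames=None, fps=None, metadata=None):
    if do_sample_frames:
        # sampling now happens inside the processor; previously callers had to do this themselves
        videos = [sample_frames(v, num_frames=num_frames, fps=fps, metadata=m) for v, m in zip(videos, metadata)]
    return videos

video = torch.zeros(32, 3, 224, 224)  # (frames, channels, height, width)
out = preprocess([video], do_sample_frames=True, num_frames=8, metadata=[{"fps": 30}])
print(out[0].shape)  # torch.Size([8, 3, 224, 224])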
@@ -35,7 +35,7 @@ from ...utils import (
 )
 from ...utils.import_utils import requires
 from ...video_processing_utils import BaseVideoProcessor
-from ...video_utils import group_videos_by_shape, reorder_videos
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
 
 
 if is_vision_available():
@@ -66,6 +66,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = InstructBlipVideoVideoProcessorInitKwargs
     model_input_names = ["pixel_values"]
 
@@ -75,6 +76,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
     def _preprocess(
         self,
         videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
         do_convert_rgb: bool,
         do_resize: bool,
         size: SizeDict,
@@ -86,10 +88,18 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
         do_pad: bool,
         rescale_factor: float,
         do_normalize: bool,
+        do_sample_frames: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
-        return_tensors: Optional[Union[str, TensorType]],
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
     ) -> BatchFeature:
+        if do_sample_frames:
+            videos = [
+                self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
+            ]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
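For the simpler processors such as this one, the fps-based path is just a proportional conversion from the video's native frame rate. A quick check of the arithmetic with made-up numbers:

# Assumed example values, not taken from the diff: a 10 s clip decoded at 30 fps.
total_num_frames = 300
metadata_fps = 30      # native fps reported in the video metadata
requested_fps = 2      # frames per second the caller wants to keep

num_frames = int(total_num_frames / metadata_fps * requested_fps)
print(num_frames)      # 20 -> 2 frames per second over a 10 s video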
@@ -21,7 +21,7 @@ from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
 from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
+from ...video_utils import VideoInput, make_batched_videos
 
 
 class InternVLImagesKwargs(ImagesKwargs, total=False):
@@ -290,32 +290,6 @@ class InternVLProcessor(ProcessorMixin):
 
         return MultiModalData(**vision_data)
 
-    def sample_indices_fn(
-        self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
-    ):
-        """
-        The function to generate indices of frames to sample from a video.
-
-        Args:
-            metadata (`VideoMetadata`):
-                `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
-            num_frames (`int`, *optional*):
-                Number of frames to sample uniformly. If None, all frames are sampled.
-            initial_shift (`bool`, `float` or `int`, defaults to `0`):
-                The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
-
-        Returns:
-            `np.ndarray`: Array of frame indices to sample.
-        """
-        num_frames = num_frames if num_frames is not None else metadata.total_num_frames
-
-        if initial_shift is True:
-            initial_shift = metadata.total_num_frames / num_frames / 2
-        indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
-            int
-        )
-        return indices
-
     def batch_decode(self, *args, **kwargs):
         """
         This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
@@ -336,39 +310,5 @@ class InternVLProcessor(ProcessorMixin):
         image_processor_input_names = self.image_processor.model_input_names
         return list(tokenizer_input_names) + list(image_processor_input_names)
 
-    # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
-    def _load_video_for_model(
-        self,
-        video: Union[str, "VideoInput"],
-        num_frames: Optional[int],
-        backend: str = "pyav",
-        initial_shift: bool = True,
-        **kwargs,
-    ) -> np.array:
-        """
-        Loads `video` to a numpy array.
-
-        Args:
-            video (`str` or `VideoInput`):
-                The video to convert to the numpy array format. Can be a link to video or local path.
-            num_frames (`int`, *optional*):
-                Number of frames to sample uniformly. If not passed, the whole video is loaded.
-            backend (`str`, *optional*, defaults to `"pyav"`):
-                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
-            initial_shift (`bool`, *optional*, defaults to `True`):
-                The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
-
-        Returns:
-            Tuple[`np.array`, Dict]: A tuple containing:
-                - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
-                - Metadata dictionary.
-        """
-
-        def sample_indices_fn_func(metadata, **fn_kwargs):
-            return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)
-
-        video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
-        return video, metadata
-
 
 __all__ = ["InternVLProcessor"]
@@ -14,25 +14,43 @@
 # limitations under the License.
 """Fast Video processor class for InternVL."""
 
+from typing import List, Optional, Union
+
+from ...image_processing_utils import BatchFeature
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
     OPENAI_CLIP_STD,
+    SizeDict,
 )
 from ...processing_utils import Unpack, VideosKwargs
 from ...utils import (
+    TensorType,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
     is_vision_available,
 )
 from ...utils.import_utils import requires
-from ...video_processing_utils import (
-    BaseVideoProcessor,
-)
+from ...video_processing_utils import BaseVideoProcessor
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
 
 
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
+
+
+if is_torch_available():
+    import torch
+
 if is_vision_available():
     from ...image_utils import PILImageResampling
 
 
-class InternVLVideoProcessorInitKwargs(VideosKwargs): ...
+class InternVLVideoProcessorInitKwargs(VideosKwargs):
+    initial_shift: Union[bool, float, int]
 
 
 @requires(backends=("torchvision",))
@@ -45,11 +63,128 @@ class InternVLVideoProcessor(BaseVideoProcessor):
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
+    initial_shift = True
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = InternVLVideoProcessorInitKwargs
     model_input_names = ["pixel_values_videos"]
 
     def __init__(self, **kwargs: Unpack[InternVLVideoProcessorInitKwargs]):
         super().__init__(**kwargs)
 
+    def sample_frames(
+        self,
+        video: "torch.Tensor",
+        metadata: Optional[Union[VideoMetadata, dict]] = None,
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+        initial_shift: Optional[Union[bool, float, int]] = None,
+    ):
+        """
+        Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+        and `fps` are mutually exclusive.
+
+        Args:
+            video (`torch.Tensor`):
+                Video that needs to be sampled.
+            metadata (`VideoMetadata`, *optional*):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+            initial_shift (`bool`, `float` or `int`, defaults to `self.initial_shift`):
+                The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.
+
+        Returns:
+            torch.Tensor:
+                Sampled video frames.
+        """
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        initial_shift = initial_shift if initial_shift is not None else self.initial_shift
+        total_num_frames = video.shape[0]
+
+        # If num_frames is not given but fps is, calculate num_frames from fps
+        if num_frames is None and fps is not None:
+            if metadata is None:
+                raise ValueError(
+                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+                    "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
+                )
+            num_frames = int(total_num_frames / metadata["fps"] * fps)
+
+        if initial_shift is True:
+            initial_shift = total_num_frames / num_frames / 2
+
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+            )
+
+        indices = torch.arange(initial_shift, total_num_frames, total_num_frames / num_frames).int()
+        video = video[indices].contiguous()
+        return video
+
+    def _preprocess(
+        self,
+        videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
+        do_convert_rgb: bool,
+        do_resize: bool,
+        size: SizeDict,
+        size_divisor: Optional[int],
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: SizeDict,
+        do_rescale: bool,
+        do_pad: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, List[float]]],
+        image_std: Optional[Union[float, List[float]]],
+        do_sample_frames: Optional[bool] = None,
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        initial_shift: Optional[Union[bool, float, int]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> BatchFeature:
+        if do_sample_frames:
+            # Sample video frames
+            videos = [
+                self.sample_frames(video, metadata, fps=fps, num_frames=num_frames, initial_shift=initial_shift)
+                for video, metadata in zip(videos, video_metadata)
+            ]
+
+        # Group videos by size for batched resizing
+        grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
+        resized_videos_grouped = {}
+        for shape, stacked_videos in grouped_videos.items():
+            if do_convert_rgb:
+                stacked_videos = self.convert_to_rgb(stacked_videos)
+            if do_resize:
+                stacked_videos = self.resize(
+                    stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation
+                )
+            resized_videos_grouped[shape] = stacked_videos
+        resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
+
+        # Group videos by size for further processing
+        # Needed in case do_resize is False, or resize returns videos with different sizes
+        grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
+        processed_videos_grouped = {}
+        for shape, stacked_videos in grouped_videos.items():
+            if do_center_crop:
+                stacked_videos = self.center_crop(stacked_videos, crop_size)
+            # Fused rescale and normalize
+            stacked_videos = self.rescale_and_normalize(
+                stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_videos_grouped[shape] = stacked_videos
+
+        processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
+        processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
+
+        return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)
+
 
 __all__ = ["InternVLVideoProcessor"]
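The `initial_shift=True` default centers the uniform sampling grid inside the clip: the shift becomes half of the sampling step, so the first and last samples sit away from the clip boundaries. A small numeric sketch with illustrative values:

import torch

total_num_frames, num_frames = 30, 5
initial_shift = total_num_frames / num_frames / 2  # 3.0, half of the 6-frame step
indices = torch.arange(initial_shift, total_num_frames, total_num_frames / num_frames).int()
print(indices.tolist())  # [3, 9, 15, 21, 27]
# Without the shift the grid would start at frame 0: [0, 6, 12, 18, 24]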
@@ -46,6 +46,7 @@ class LlavaNextVideoVideoProcessor(BaseVideoProcessor):
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = LlavaNextVideoFastVideoProcessorInitKwargs
     model_input_names = ["pixel_values_videos"]
 
@@ -47,6 +47,7 @@ class LlavaOnevisionVideoProcessor(BaseVideoProcessor):
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = LlavaOnevisionFastVideoProcessorInitKwargs
     model_input_names = ["pixel_values_videos"]
 
@@ -154,7 +154,7 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
         seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk")
         position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds")
         use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video")
-        fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
+        fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
 
         if audio is not None:
             output_kwargs["audio_kwargs"]["padding"] = "max_length"  # Support "max_length" padding only here
@@ -928,7 +928,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
             "padding": False,
             "return_mm_token_type_ids": False,
         },
-        "videos_kwargs": {"fps": 2.0},
     }
 
 
@@ -1013,9 +1012,7 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
             image_grid_thw = image_inputs["image_grid_thw"]
 
         if videos is not None:
-            # pop fps in advance for passing kwargs validation
-            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
+            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
 
             videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
             video_grid_thw = videos_inputs["video_grid_thw"]
 
@@ -54,7 +54,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
             "padding": False,
             "return_mm_token_type_ids": False,
         },
-        "videos_kwargs": {"fps": 2.0},
     }
 
 
@@ -151,9 +150,7 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
             image_grid_thw = image_inputs["image_grid_thw"]
 
         if videos is not None:
-            # pop fps in advance for passing kwargs validation
-            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
+            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
 
             videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
             video_grid_thw = videos_inputs["video_grid_thw"]
 
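The `.pop(...)` to `.get(...)` swap matters because the video processor itself now consumes `fps` when sampling frames; a minimal sketch of the difference in what gets forwarded (plain dicts, not the real kwargs objects):

videos_kwargs = {"fps": 2.0, "do_sample_frames": True}

fps = videos_kwargs.get("fps", 2.0)  # new behavior: value is read but stays in the kwargs
print(videos_kwargs)                 # {'fps': 2.0, 'do_sample_frames': True} -> forwarded downstream

fps = videos_kwargs.pop("fps", 2.0)  # old behavior: value is removed before forwarding
print(videos_kwargs)                 # {'do_sample_frames': True} -> the video processor never saw fps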
@@ -19,6 +19,7 @@
 # limitations under the License.
 """video processor class for Qwen2-VL."""
 
+import math
 from typing import List, Optional, Union
 
 from ...image_processing_utils import (
@@ -45,7 +46,7 @@ from ...video_processing_utils import (
     BASE_VIDEO_PROCESSOR_DOCSTRING,
     BaseVideoProcessor,
 )
-from ...video_utils import group_videos_by_shape, reorder_videos
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
 
 
 if is_vision_available():
@@ -69,6 +70,8 @@ class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
     patch_size: Optional[int]
     temporal_patch_size: Optional[int]
     merge_size: Optional[int]
+    min_frames: Optional[int]
+    max_frames: Optional[int]
 
 
 @add_start_docstrings(
@@ -85,23 +88,30 @@ class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
             The temporal patch size of the vision encoder.
         merge_size (`int`, *optional*, defaults to 2):
             The merge size of the vision encoder to llm encoder.
+        min_frames (`int`, *optional*, defaults to 4):
+            The minimum number of frames that can be sampled.
+        max_frames (`int`, *optional*, defaults to 768):
+            The maximum number of frames that can be sampled.
     """,
 )
 @requires(backends=("torchvision",))
 class Qwen2VLVideoProcessor(BaseVideoProcessor):
     resample = PILImageResampling.BICUBIC
-    size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
+    size = {"shortest_edge": 128 * 28 * 28, "longest_edge": 28 * 28 * 768}
     image_mean = OPENAI_CLIP_MEAN
     image_std = OPENAI_CLIP_STD
     do_resize = True
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
-    min_pixels = 56 * 56
-    max_pixels = 28 * 28 * 1280
+    min_pixels = 128 * 28 * 28
+    max_pixels = 28 * 28 * 768
     patch_size = 14
     temporal_patch_size = 2
     merge_size = 2
+    min_frames = 4
+    max_frames = 768
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = Qwen2VLVideoProcessorInitKwargs
     model_input_names = ["pixel_values_videos", "video_grid_thw"]
 
@@ -109,9 +119,80 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
         super().__init__(**kwargs)
         self.size = {"shortest_edge": self.min_pixels, "longest_edge": self.max_pixels}
 
+    def sample_frames(
+        self,
+        video: "torch.Tensor",
+        frame_factor: int,
+        min_frames: int,
+        max_frames: int,
+        metadata: Optional[Union[VideoMetadata, dict]] = None,
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+    ):
+        """
+        Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+        and `fps` are mutually exclusive.
+
+        Args:
+            video (`torch.Tensor`):
+                Video that needs to be sampled.
+            frame_factor (`int`):
+                The temporal patch size of the vision encoder. Number of sampled frames will be rounded to be divisible by frame factor.
+            min_frames (`int`):
+                The minimum number of frames that can be sampled.
+            max_frames (`int`):
+                The maximum number of frames that can be sampled.
+            metadata (`VideoMetadata`, *optional*):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+
+        Returns:
+            torch.Tensor:
+                Sampled video frames.
+        """
+        if fps is not None and num_frames is not None:
+            raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
+
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        fps = fps if fps is not None else self.fps
+        total_num_frames = video.shape[0]
+
+        # If num_frames is not given but fps is, calculate num_frames from fps
+        if num_frames is not None:
+            num_frames = round(num_frames / frame_factor) * frame_factor
+        elif fps is not None:
+            if metadata is None:
+                raise ValueError(
+                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+                    "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
+                )
+            max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor
+            num_frames = total_num_frames / metadata["fps"] * fps
+            num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)
+            num_frames = math.floor(num_frames / frame_factor) * frame_factor
+
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+                "Decrease `num_frames` or `fps` for sampling."
+            )
+
+        if num_frames is not None:
+            indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
+        else:
+            indices = torch.arange(0, total_num_frames).int()
+        video = video[indices].contiguous()
+
+        return video
+
     def _preprocess(
         self,
         videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
         do_convert_rgb: bool,
         do_resize: bool,
         size: SizeDict,
@@ -119,6 +200,7 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
         do_rescale: bool,
         rescale_factor: float,
         do_normalize: bool,
+        do_sample_frames: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
         min_pixels: Optional[int] = None,
@@ -126,9 +208,28 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
         patch_size: Optional[int] = None,
         temporal_patch_size: Optional[int] = None,
         merge_size: Optional[int] = None,
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        min_frames: Optional[int] = None,
+        max_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ):
+        if do_sample_frames:
+            # Sample video frames
+            videos = [
+                self.sample_frames(
+                    video,
+                    frame_factor=temporal_patch_size,
+                    min_frames=min_frames,
+                    max_frames=max_frames,
+                    metadata=metadata,
+                    num_frames=num_frames,
+                    fps=fps,
+                )
+                for video, metadata in zip(videos, video_metadata)
+            ]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
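Qwen2-VL additionally clamps the fps-derived count between `min_frames` and `max_frames` and rounds it to a multiple of the temporal patch size (`frame_factor`). Re-running that arithmetic with illustrative numbers:

import math

total_num_frames, metadata_fps = 300, 30  # e.g. a 10 s clip at 30 fps (made-up values)
fps, frame_factor = 2, 2                  # target sampling rate and temporal patch size
min_frames, max_frames = 4, 768

max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor  # 300
num_frames = total_num_frames / metadata_fps * fps                                        # 20.0
num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames)          # 20.0
num_frames = math.floor(num_frames / frame_factor) * frame_factor                         # 20
print(num_frames)  # 20, guaranteed to be divisible by the temporal patch size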
@@ -16,18 +16,15 @@
 Processor class for SmolVLM.
 """
 
-import copy
 from datetime import timedelta
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
-import numpy as np
-
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, make_nested_list_of_images
-from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...processing_utils import AllKwargsForChatTemplate, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import BatchEncoding, TextInput
 from ...utils import is_num2words_available, is_vision_available, logging
-from ...video_utils import VideoInput, load_video, make_batched_videos
+from ...video_utils import VideoInput
 
 
 if is_vision_available():
@@ -35,7 +32,13 @@ if is_vision_available():
         DEFAULT_MEDIA_OUTTRO,
         DEFAULT_VIDEO_INTRO,
         FRAME_TIMESTAMP_MESSAGE,
-        smolvlm_sample_indices_fn,
+    )
+
+if is_vision_available():
+    from .video_processing_smolvlm import (
+        DEFAULT_MEDIA_OUTTRO,
+        DEFAULT_VIDEO_INTRO,
+        FRAME_TIMESTAMP_MESSAGE,
     )
 
 if TYPE_CHECKING:
@@ -50,6 +53,10 @@ else:
     num2words = None
 
 
+# The correct chat template to be used for videos after #38105
+DEFAULT_CHAT_TEMPLATE = "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% elif line['type'] == 'video' %}{{ '<video>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+
+
 def _prompt_split_image(
     image_seq_len, image_rows, image_cols, fake_token_around_image, image_token, global_image_token
 ):
@@ -140,9 +147,7 @@ class SmolVLMProcessor(ProcessorMixin):
 
     attributes = ["image_processor", "tokenizer", "video_processor"]
     image_processor_class = "SmolVLMImageProcessor"
-    video_processor_class = (
-        "SmolVLMImageProcessor"  # TODO: raushan should be VideoProcessor when LANCZOS resizing is settled
-    )
+    video_processor_class = "SmolVLMVideoProcessor"  # NOTE: uses different interpolation than slow processors
     tokenizer_class = "AutoTokenizer"
 
     def __init__(
|
|||||||
self.end_of_utterance_token = getattr(tokenizer, "end_of_utterance_token", "<end_of_utterance>")
|
self.end_of_utterance_token = getattr(tokenizer, "end_of_utterance_token", "<end_of_utterance>")
|
||||||
self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
|
self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
|
||||||
self.image_seq_len = image_seq_len
|
self.image_seq_len = image_seq_len
|
||||||
|
self.video_token = getattr(tokenizer, "video_token", "<video>")
|
||||||
self.video_size = video_processor.video_sampling["video_size"]
|
|
||||||
self.image_size = image_processor.size
|
|
||||||
|
|
||||||
self.do_image_splitting = image_processor.do_image_splitting
|
|
||||||
self.do_video_splitting = video_processor.video_sampling.get("do_image_splitting", False)
|
|
||||||
|
|
||||||
self.default_max_frames = video_processor.video_sampling["max_frames"]
|
|
||||||
self.default_fps = video_processor.video_sampling["fps"]
|
|
||||||
# Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits, optionally surrounded by newline characters
|
|
||||||
# self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
|
|
||||||
|
|
||||||
if not num2words:
|
if not num2words:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@ -179,16 +174,12 @@ class SmolVLMProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
|
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
|
||||||
|
|
||||||
def process_vision(
|
def process_vision(self, text, images, output_kwargs):
|
||||||
self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None, processor=None
|
|
||||||
):
|
|
||||||
if text is not None:
|
if text is not None:
|
||||||
n_images_in_text = [sample.count(self.image_token) for sample in text]
|
n_images_in_text = [sample.count(self.image_token) for sample in text]
|
||||||
|
|
||||||
n_images_in_images = [len(sublist) for sublist in images]
|
n_images_in_images = [len(sublist) for sublist in images]
|
||||||
image_inputs = processor(
|
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
images, do_image_splitting=do_image_splitting, size=image_processor_size, **output_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if text is None:
|
if text is None:
|
||||||
return None, image_inputs
|
return None, image_inputs
|
||||||
@ -227,6 +218,50 @@ class SmolVLMProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return prompt_strings, image_inputs
|
return prompt_strings, image_inputs
|
||||||
|
|
||||||
|
def process_video(self, text, videos, output_kwargs):
|
||||||
|
if text is not None:
|
||||||
|
n_videos_in_text = [sample.count(self.video_token) for sample in text]
|
||||||
|
|
||||||
|
n_videos_in_videos = [len(sublist) for sublist in videos]
|
||||||
|
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
|
||||||
|
|
||||||
|
num_frames = video_inputs["pixel_values"].shape[1]
|
||||||
|
batch_timestamps = iter(video_inputs.pop("timestamps"))
|
||||||
|
batch_durations = iter(video_inputs.pop("durations"))
|
||||||
|
|
||||||
|
if text is None:
|
||||||
|
return None, video_inputs
|
||||||
|
|
||||||
|
if n_videos_in_videos != n_videos_in_text:
|
||||||
|
raise ValueError(
|
||||||
|
f"The number of videos in the text {n_videos_in_text} and videos {n_videos_in_videos} should be the same."
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_strings = []
|
||||||
|
for sample in text:
|
||||||
|
while self.video_token in sample:
|
||||||
|
timestamps = next(batch_timestamps)
|
||||||
|
duration = next(batch_durations)
|
||||||
|
duration_td = timedelta(seconds=int(duration))
|
||||||
|
image_prompt_strings = DEFAULT_VIDEO_INTRO.format(
|
||||||
|
frame_count=num2words(num_frames), video_duration=str(duration_td)
|
||||||
|
)
|
||||||
|
for timestamp in timestamps:
|
||||||
|
image_prompt_string = _prompt_single_image(
|
||||||
|
self.image_seq_len,
|
||||||
|
image_token=self.image_token,
|
||||||
|
fake_token_around_image=self.fake_image_token,
|
||||||
|
global_image_token=self.global_image_token,
|
||||||
|
)
|
||||||
|
timestamp = f"{timestamp[0]:02d}:{timestamp[1]:02d}"
|
||||||
|
image_prompt_string = FRAME_TIMESTAMP_MESSAGE.format(timestamp=timestamp) + image_prompt_string
|
||||||
|
image_prompt_strings += image_prompt_string
|
||||||
|
|
||||||
|
image_prompt_strings += DEFAULT_MEDIA_OUTTRO
|
||||||
|
sample = sample.replace(self.video_token, image_prompt_strings, 1)
|
||||||
|
prompt_strings.append(sample)
|
||||||
|
return prompt_strings, video_inputs
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
|
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
|
||||||
@ -310,21 +345,14 @@ class SmolVLMProcessor(ProcessorMixin):
|
|||||||
text, vision_inputs = self.process_vision(
|
text, vision_inputs = self.process_vision(
|
||||||
text,
|
text,
|
||||||
images,
|
images,
|
||||||
output_kwargs["images_kwargs"],
|
output_kwargs,
|
||||||
do_image_splitting=self.do_image_splitting,
|
|
||||||
image_processor_size=self.image_size,
|
|
||||||
processor=self.image_processor,
|
|
||||||
)
|
)
|
||||||
inputs.update(vision_inputs)
|
inputs.update(vision_inputs)
|
||||||
elif videos is not None:
|
elif videos is not None:
|
||||||
videos = make_batched_videos(videos)
|
text, vision_inputs = self.process_video(
|
||||||
text, vision_inputs = self.process_vision(
|
|
||||||
text,
|
text,
|
||||||
videos,
|
videos,
|
||||||
output_kwargs["videos_kwargs"],
|
output_kwargs,
|
||||||
do_image_splitting=self.do_image_splitting,
|
|
||||||
image_processor_size=self.video_size,
|
|
||||||
processor=self.video_processor,
|
|
||||||
)
|
)
|
||||||
inputs.update(vision_inputs)
|
inputs.update(vision_inputs)
|
||||||
|
|
||||||
@ -337,93 +365,6 @@ class SmolVLMProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return BatchFeature(inputs, tensor_type=return_tensors)
|
return BatchFeature(inputs, tensor_type=return_tensors)
|
||||||
|
|
||||||
def _process_messages_for_chat_template(
|
|
||||||
self,
|
|
||||||
conversations: List[List[Dict[str, str]]],
|
|
||||||
batch_images: List[ImageInput],
|
|
||||||
batch_videos: List[VideoInput],
|
|
||||||
batch_video_metadata: List[List[Dict[str, any]]],
|
|
||||||
**chat_template_kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Used within `apply_chat_template` when a model has special way to process conversation history. For example,
|
|
||||||
video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
|
|
||||||
were sampled. This information cannot be accessed before the video is loaded.
|
|
||||||
For most models it is a no-op, must be overridden by model processors which require special processing.
|
|
||||||
Args:
|
|
||||||
conversation (`List[Dict, str, str]`):
|
|
||||||
The conversation to process. Always comes in batched format.
|
|
||||||
batch_images (`List[List[ImageInput]]`):
|
|
||||||
Batch of images that were loaded from url/path defined in the conversation. The images
|
|
||||||
are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
|
|
||||||
per batch.
|
|
||||||
batch_videos (`List[List[ImageInput]]`):
|
|
||||||
Batch of videos that were loaded from url/path defined in the conversation. The videos
|
|
||||||
are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
|
|
||||||
per batch.
|
|
||||||
batch_video_metadata (`List[List[Dict[[str, any]]]]`):
|
|
||||||
Batch of metadata returned from loading videos. That includes video fps, duration and total number of framer in original video.
|
|
||||||
Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
|
|
||||||
per batch.
|
|
||||||
"""
|
|
||||||
# We don't want to modify in-place the messages passed by user
|
|
||||||
# The user might want to add new turn on conv and continue generation
|
|
||||||
conversations = copy.deepcopy(conversations)
|
|
||||||
batch_num_frames, batch_timestamps = [], []
|
|
||||||
for metadata_list, video_list in zip(batch_video_metadata, batch_videos):
|
|
||||||
for metadata, video in zip(metadata_list, video_list):
|
|
||||||
duration_sec = getattr(metadata, "duration")
|
|
||||||
frames_idx = getattr(metadata, "frames_indices")
|
|
||||||
fps = getattr(metadata, "fps")
|
|
||||||
|
|
||||||
timestamps = []
|
|
||||||
for idx, frame_np in zip(frames_idx, video):
|
|
||||||
sec = idx / fps
|
|
||||||
mm = int(sec // 60)
|
|
||||||
ss = int(sec % 60)
|
|
||||||
timestamps.append(f"{mm:02d}:{ss:02d}")
|
|
||||||
batch_timestamps.append(timestamps)
|
|
||||||
batch_num_frames.append(len(video))
|
|
||||||
|
|
||||||
for conversation in conversations:
|
|
||||||
# For each message, scan content for {"type": "video"}
|
|
||||||
for msg in conversation:
|
|
||||||
if "content" not in msg:
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_content = []
|
|
||||||
for block in msg["content"]:
|
|
||||||
if block.get("type") == "video":
|
|
||||||
curr_timestamps = batch_timestamps.pop(0)
|
|
||||||
curr_num_frames = batch_num_frames.pop(0)
|
|
||||||
|
|
||||||
# Build the video intro texts
|
|
||||||
td = timedelta(seconds=int(duration_sec))
|
|
||||||
new_content.append(
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": DEFAULT_VIDEO_INTRO.format(
|
|
||||||
frame_count=num2words(curr_num_frames), video_duration=str(td)
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Insert per-frame lines: "Frame from {timestamp}:", then an "image" block
|
|
||||||
for i, ts in enumerate(curr_timestamps):
|
|
||||||
new_content.append({"type": "text", "text": FRAME_TIMESTAMP_MESSAGE.format(timestamp=ts)})
|
|
||||||
new_content.append({"type": "image"})
|
|
||||||
|
|
||||||
# 3) Optionally add an outro (e.g. "Now answer the question:")
|
|
||||||
new_content.append({"type": "text", "text": DEFAULT_MEDIA_OUTTRO})
|
|
||||||
# Do NOT add the original block => we skip it (since we've replaced it)
|
|
||||||
else:
|
|
||||||
# keep original block
|
|
||||||
new_content.append(block)
|
|
||||||
|
|
||||||
# update the content
|
|
||||||
msg["content"] = new_content
|
|
||||||
return conversations
|
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to SmolVLMTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
@ -446,45 +387,54 @@ class SmolVLMProcessor(ProcessorMixin):
|
|||||||
image_processor_input_names = self.image_processor.model_input_names
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
|
return list(dict.fromkeys(image_processor_input_names + tokenizer_input_names))
|
||||||
|
|
||||||
# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
|
def apply_chat_template(
|
||||||
def _load_video_for_model(
|
|
||||||
self,
|
self,
|
||||||
video: Union[str, "VideoInput"],
|
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
|
||||||
num_frames: Optional[int] = None,
|
chat_template: Optional[str] = None,
|
||||||
fps: Optional[int] = None,
|
**kwargs: Unpack[AllKwargsForChatTemplate],
|
||||||
backend: str = "opencv",
|
) -> str:
|
||||||
skip_secs: int = 0.0,
|
|
||||||
**kwargs,
|
|
||||||
) -> np.array:
|
|
||||||
"""
|
"""
|
||||||
Loads `video` to a numpy array.
|
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
|
||||||
|
conversations to turn them into a single tokenizable string.
|
||||||
|
|
||||||
|
The input is expected to be in the following format, where each message content is a list consisting of text and
|
||||||
|
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
|
||||||
|
`pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
|
||||||
|
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||||
|
{"type": "text", "text": "Please describe this image in detail."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
video (`str` or `VideoInput`):
|
conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
|
||||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
The conversation to format.
|
||||||
num_frames (`int`, *optional*):
|
chat_template (`Optional[str]`, *optional*):
|
||||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
|
||||||
fps (`int`, *optional*):
|
chat template is used.
|
||||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
|
||||||
If not specified and `num_frames==None`, all frames are sampled.
|
|
||||||
backend (`str`, *optional*, defaults to `"opencv"`):
|
|
||||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[`np.array`, Dict]: A tuple containing:
|
|
||||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
|
||||||
- Metadata dictionary.
|
|
||||||
"""
|
"""
|
||||||
max_frames = self.default_max_frames if num_frames is None else num_frames
|
if isinstance(conversation, (list, tuple)) and (
|
||||||
target_fps = self.default_fps if fps is None else fps
|
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
|
||||||
|
):
|
||||||
|
conversations = conversation
|
||||||
|
else:
|
||||||
|
conversations = [conversation]
|
||||||
|
|
||||||
def sample_indices_fn_func(metadata, **fn_kwargs):
|
has_video = any(
|
||||||
return smolvlm_sample_indices_fn(
|
(isinstance(content, dict) and content["type"] == "video")
|
||||||
metadata, max_frames=max_frames, target_fps=target_fps, skip_secs=skip_secs, **fn_kwargs
|
for conversation in conversations
|
||||||
)
|
for message in conversation
|
||||||
|
for content in message["content"]
|
||||||
video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
|
)
|
||||||
return video, metadata
|
if chat_template is None and has_video:
|
||||||
|
# re-assign to the correct default template for BC, if user is not requesting their own template
|
||||||
|
chat_template = DEFAULT_CHAT_TEMPLATE
|
||||||
|
return super().apply_chat_template(conversation, chat_template, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["SmolVLMProcessor"]
|
__all__ = ["SmolVLMProcessor"]
|
||||||
|
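The timestamps consumed by `process_video` arrive from the video processor as `(minutes, seconds)` pairs; producing such a pair from a sampled frame index and the native fps is a simple conversion. A sketch of that step (the helper name is made up for illustration):

def frame_index_to_timestamp(frame_idx: int, native_fps: float) -> tuple[int, int]:
    # Convert a frame index into the (minutes, seconds) pair used for "Frame from MM:SS:" prompts.
    seconds = frame_idx / native_fps
    return int(seconds // 60), int(seconds % 60)

mm, ss = frame_index_to_timestamp(1830, 30.0)  # frame 1830 of a 30 fps video
print(f"{mm:02d}:{ss:02d}")                    # 01:01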
@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import List, Optional, Union
 
 import numpy as np
 
 from ...image_processing_utils import (
     BatchFeature,
+    get_size_dict,
 )
 from ...image_utils import (
     IMAGENET_STANDARD_MEAN,
@@ -38,7 +38,7 @@ from ...utils.import_utils import requires
 from ...video_processing_utils import (
     BaseVideoProcessor,
 )
-from ...video_utils import group_videos_by_shape, reorder_videos
+from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
 
 
 if is_vision_available():
@@ -68,66 +68,6 @@ FRAME_TIMESTAMP_MESSAGE = "\nFrame from {timestamp}:"
 MAX_IMAGE_SIZE = 4096  # 4k resolution as absolute maximum
 
 
-def smolvlm_sample_indices_fn(metadata, max_frames, target_fps, skip_secs=0):
-    """
-    Example sampling function which:
-      - Uses `max_frames` (if provided) or calculates it from `fps` and metadata.
-      - Applies a basic center-skip if fewer frames than available, otherwise
-        optionally skips `skip_secs` from both the start and end.
-      - Uniformly samples the desired number of frames between the start and end indices.
-
-    Args:
-        max_frames (`int`):
-            Maximum number of frames to sample.
-        target_fps (`int`):
-            Target frames to sample per second.
-        metadata (`dict`):
-            Contains video metadata such as "n_frames" and "video_fps".
-        skip_secs (`float`, *optional*, defaults to 1.0):
-            Number of seconds to skip from the start and end if the video is long enough.
-
-    Returns:
-        numpy.ndarray:
-            An array of unique frame indices to sample.
-    """
-
-    total_num_frames = getattr(metadata, "total_num_frames", 0)
-    if total_num_frames <= 0:
-        raise ValueError(f"Invalid total_num_frames={total_num_frames} in metadata.")
-
-    native_fps = getattr(metadata, "fps", 30.0)
-    duration_seconds = getattr(metadata, "duration", 0)
-
-    if duration_seconds <= 0:
-        raise ValueError(f"Invalid duration_seconds={duration_seconds} in metadata.")
-
-    # Step 1) Estimate how many frames we'd sample at `target_fps`, fallback if target_fps <= 0
-    estimated_frames = int(round(target_fps * duration_seconds))
-
-    # Step 2) desired_frames
-    desired_frames = min(estimated_frames, max_frames)
-    if desired_frames < 1:
-        desired_frames = 1
-
-    # Step 3) center skip logic
-    start_idx = 0
-    end_idx = total_num_frames - 1
-
-    if skip_secs > 0 and (duration_seconds - 2 * skip_secs) > (max_frames * target_fps):
-        start_idx = int(skip_secs * native_fps)
-        end_idx = int(total_num_frames - skip_secs * native_fps)
-
-    start_idx = max(0, start_idx)
-    end_idx = min(end_idx, total_num_frames - 1)
-    if start_idx >= end_idx:
-        start_idx, end_idx = 0, total_num_frames - 1
-
-    indices = np.linspace(start_idx, end_idx, desired_frames, dtype=int)
-    indices = np.unique(indices)
-
-    return indices
-
-
 def get_max_height_width(videos: list["torch.Tensor"]) -> List[int]:
     """
     Get the maximum height and width across all videos in a batch.
|
|||||||
return height, width
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
class SmolVLMVideoProcessorInitKwargs(VideosKwargs): ...
|
class SmolVLMVideoProcessorInitKwargs(VideosKwargs):
|
||||||
|
max_image_size: dict[str, int] = None
|
||||||
|
|
||||||
|
|
||||||
@requires(backends=("torchvision",))
|
@requires(backends=("torchvision",))
|
||||||
class SmolVLMVideoProcessor(BaseVideoProcessor):
|
class SmolVLMVideoProcessor(BaseVideoProcessor):
|
||||||
resample = PILImageResampling.LANCZOS
|
resample = PILImageResampling.LANCZOS
|
||||||
size = {"longest_edge": 4 * 364}
|
size = {"longest_edge": 4 * 364}
|
||||||
|
max_image_size = {"longest_edge": 364}
|
||||||
image_mean = IMAGENET_STANDARD_MEAN
|
image_mean = IMAGENET_STANDARD_MEAN
|
||||||
image_std = IMAGENET_STANDARD_STD
|
image_std = IMAGENET_STANDARD_STD
|
||||||
do_resize = True
|
do_resize = True
|
||||||
@ -194,11 +136,21 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
|
|||||||
do_normalize = True
|
do_normalize = True
|
||||||
do_convert_rgb = True
|
do_convert_rgb = True
|
||||||
do_pad = True
|
do_pad = True
|
||||||
|
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
|
||||||
valid_kwargs = SmolVLMVideoProcessorInitKwargs
|
valid_kwargs = SmolVLMVideoProcessorInitKwargs
|
||||||
model_input_names = ["pixel_values", "pixel_attention_mask"]
|
model_input_names = ["pixel_values", "pixel_attention_mask"]
|
||||||
|
|
||||||
def __init__(self, **kwargs: Unpack[SmolVLMVideoProcessorInitKwargs]):
|
def __init__(self, **kwargs: Unpack[SmolVLMVideoProcessorInitKwargs]):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
# For BC pop values from `config.video_sampling`. In official config `video_sampling` is guaranteed to be present
|
||||||
|
# We check for `Noneness` only for certain tests such as `test_init_without_params`
|
||||||
|
if "size" in kwargs and "video_sampling" in kwargs:
|
||||||
|
kwargs["video_sampling"]["video_size"] = kwargs["size"]
|
||||||
|
|
||||||
|
if "video_sampling" in kwargs:
|
||||||
|
self.num_frames = kwargs["video_sampling"]["max_frames"]
|
||||||
|
self.fps = kwargs["video_sampling"]["fps"]
|
||||||
|
self.size = get_size_dict(kwargs["video_sampling"]["video_size"], default_to_square=self.default_to_square)
|
||||||
|
|
||||||
def resize(
|
def resize(
|
||||||
self,
|
self,
|
||||||
@ -240,12 +192,20 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
|
|||||||
new_size = (size.height, size.width)
|
new_size = (size.height, size.width)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Size must contain 'height' and 'width' keys, or 'longest_edge' key. Got {size}.")
|
raise ValueError(f"Size must contain 'height' and 'width' keys, or 'longest_edge' key. Got {size}.")
|
||||||
return F.resize(video, new_size, interpolation=interpolation, antialias=antialias)
|
|
||||||
|
video = F.resize(video, new_size, interpolation=interpolation, antialias=antialias)
|
||||||
|
|
||||||
|
# Resize again to match image processor when `do_image_splitting=False`. Frames have to be squared to `max_image_size`
|
||||||
|
# NOTE: videos are always processoed without image splitting
|
||||||
|
max_size = self.max_image_size["longest_edge"], self.max_image_size["longest_edge"]
|
||||||
|
video = F.resize(video, max_size, interpolation=interpolation, antialias=antialias)
|
||||||
|
return video
|
||||||
|
|
||||||
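For intuition (not from the repo): with the double resize above, every frame first follows the regular `size` logic and is then forced to a square `max_image_size`. A standalone torchvision sketch with made-up intermediate dimensions:

    import torch
    from torchvision.transforms.v2 import functional as F

    frame = torch.rand(3, 1092, 1456)      # (C, H, W)
    step1 = F.resize(frame, [728, 971])    # stand-in for whatever the size-dependent first resize yields (assumed)
    step2 = F.resize(step1, [364, 364])    # squared to max_image_size["longest_edge"]
    print(step2.shape)                     # torch.Size([3, 364, 364])
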
     def pad(
         self,
         video: "torch.Tensor",
         padded_size: tuple[int, int],
+        max_num_frames: int,
         fill: int = 0,
         return_pixel_mask: bool = True,
     ):
@@ -255,24 +215,28 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
                 Video to pad.
             padded_size (`Tuple[int, int]`):
                 Height and width to pad.
+            max_num_frames (`int`):
+                The maximum number of frames to which video will be padded.
             fill (`int`, *optional*):
                 The value to use for the padding.
             return_pixel_mask (`bool`, *optional*, defaults to `True`):
                 Whether to return a pixel mask.
         """
         original_size = video.size()[-2:]
-        padding_bottom = padded_size[0] - original_size[0]
-        padding_right = padded_size[1] - original_size[1]
-        if padding_bottom < 0 or padding_right < 0:
+        padding_height = padded_size[0] - original_size[0]
+        padding_width = padded_size[1] - original_size[1]
+        padding_frame = max_num_frames - video.shape[0]
+        if padding_width < 0 or padding_height < 0:
             raise ValueError(
                 f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                 f"original size. Got padded size: {padded_size}, original size: {original_size}."
             )
         if original_size != padded_size:
-            padding = [0, 0, padding_right, padding_bottom]
+            padding = [0, padding_width, 0, padding_height, 0, 0, 0, padding_frame]
             video = F.pad(video, padding, fill=fill)

         # Make a pixel mask for the video, where 1 indicates a valid pixel and 0 indicates padding.
+        # Mask shape is (num_frames, height, width) so we omit the channel dim
         pixel_mask = None
         if return_pixel_mask:
             pixel_mask = torch.zeros_like(video[..., 0, :, :], dtype=torch.int64)
@@ -280,9 +244,79 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):

         return video, pixel_mask

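For orientation (not part of the diff): the new 8-value `padding` list follows the last-dimension-first convention used by `torch.nn.functional.pad`, padding width, then height, leaving channels alone, and finally appending `padding_frame` empty frames. A standalone sketch with made-up sizes:

    import torch
    import torch.nn.functional as F

    video = torch.rand(5, 3, 360, 300)    # (num_frames, channels, height, width)
    padding = [0, 20, 0, 24, 0, 0, 0, 3]  # +20 width, +24 height, +3 trailing frames
    padded = F.pad(video, padding, value=0)
    print(padded.shape)                   # torch.Size([8, 3, 384, 320])
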
+    def sample_frames(
+        self,
+        video: "torch.Tensor",
+        metadata: Union[VideoMetadata, dict],
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+        skip_secs: Optional[int] = 1,
+    ):
+        """
+        Video sampling function which:
+            - Uses `num_frames` (if provided) or calculates it from `fps` and metadata.
+            - Applies a basic center-skip if fewer frames than available, otherwise
+              optionally skips `skip_secs` from both the start and end.
+            - Uniformly samples the desired number of frames between the start and end indices.
+
+        Args:
+            video (`torch.Tensor`):
+                Video that needs to be sampled.
+            metadata (`VideoMetadata`):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+            skip_secs (`float`, *optional*, defaults to `1`):
+                Number of seconds to skip from the start and end if the video is long enough.
+
+        Returns:
+            torch.Tensor:
+                Sampled video frames.
+        """
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        fps = fps if fps is not None else self.fps
+
+        total_num_frames = video.shape[0]
+
+        # Step 1) Estimate how many frames we'd sample at `target_fps`, fallback if target_fps <= 0
+        estimated_frames = int(round(fps * metadata["duration"]))
+
+        # Step 2) desired_frames
+        desired_frames = min(estimated_frames, num_frames)
+        if desired_frames < 1:
+            desired_frames = 1
+
+        # Step 3) center skip logic
+        start_idx = 0
+        end_idx = total_num_frames - 1
+
+        if skip_secs > 0 and (metadata["duration"] - 2 * skip_secs) > (num_frames * fps):
+            start_idx = int(skip_secs * metadata["fps"])
+            end_idx = int(total_num_frames - skip_secs * metadata["fps"])
+
+        start_idx = max(0, start_idx)
+        end_idx = min(end_idx, total_num_frames - 1)
+        if start_idx >= end_idx:
+            start_idx, end_idx = 0, total_num_frames - 1
+
+        indices = np.linspace(start_idx, end_idx, desired_frames, dtype=int)
+        indices = np.unique(indices)
+        video = video[indices].contiguous()
+
+        timestamps = []
+        for idx in indices:
+            sec = idx / metadata["fps"]
+            mm = int(sec // 60)
+            ss = int(sec % 60)
+            timestamps.append([mm, ss])
+        return video, timestamps, int(metadata["duration"])

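A rough usage sketch (not from the repo) of the new `sample_frames` above, assuming a 10-second clip recorded at 30 fps; the export name, tensor shapes and metadata values are illustrative only:

    import torch
    from transformers import SmolVLMVideoProcessor  # export name assumed

    video_processor = SmolVLMVideoProcessor()
    video = torch.randint(0, 255, (300, 3, 384, 384), dtype=torch.uint8)  # fake 300-frame clip
    metadata = {"total_num_frames": 300, "fps": 30, "duration": 10.0, "video_backend": "pyav"}

    frames, timestamps, duration = video_processor.sample_frames(video, metadata, num_frames=64, fps=1)
    # fps=1 over a 10 s clip keeps roughly 10 uniformly spaced frames,
    # each paired with a [minutes, seconds] timestamp; duration == 10
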
     def _preprocess(
         self,
         videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
         do_convert_rgb: bool,
         do_resize: bool,
         size: SizeDict,
@@ -291,13 +325,38 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
         rescale_factor: float,
         do_normalize: bool,
         do_pad: bool,
+        do_sample_frames: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        skip_secs: Optional[int] = 0,
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ):
         # Group videos by size for batched resizing
-        grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
+        if do_sample_frames:
+            if video_metadata[0] is None:
+                raise ValueError(
+                    "Frame sampling is enabled but no video metadata was found. SmolVLM requires metadata to correctly sample frames. "
+                    "Please pass in `VideoMetadata` object per each input video or set `do_sample_frames=False`"
+                )
+            processed_videos = []
+            timestamps_list, durations_list = [], []
+            for video, metadata in zip(videos, video_metadata):
+                video, timestamps, duration = self.sample_frames(video, metadata, num_frames, fps, skip_secs)
+                timestamps_list.append(timestamps)
+                durations_list.append(duration)
+                processed_videos.append(video)
+        else:
+            # Assume 24 fps by default and prepare timestamps for the whole video when all frames are sampled
+            processed_videos = videos
+            timestamps_list = [
+                [(int((idx / 24) // 60), int((idx / 24) % 60)) for idx in range(len(video))] for video in videos
+            ]
+            durations_list = [len(video) // 24 for video in videos]
+
+        grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
         resized_videos_grouped = {}
         for shape, stacked_videos in grouped_videos.items():
             if do_convert_rgb:
@@ -319,12 +378,15 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):

         if do_pad:
             pad_size = get_max_height_width(processed_videos)
+            max_num_frames = max(len(video) for video in processed_videos)
             grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
             processed_padded_mask_grouped = {}
             processed_videos_grouped = {}

             for shape, stacked_videos in grouped_videos.items():
-                stacked_videos, padded_masks = self.pad(stacked_videos, padded_size=pad_size)
+                stacked_videos, padded_masks = self.pad(
+                    stacked_videos, padded_size=pad_size, max_num_frames=max_num_frames
+                )
                 processed_videos_grouped[shape] = stacked_videos
                 processed_padded_mask_grouped[shape] = padded_masks

@@ -332,7 +394,7 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
         pixel_attention_mask = reorder_videos(processed_padded_mask_grouped, grouped_videos_index)

         processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
-        data = {"pixel_values": processed_videos}
+        data = {"pixel_values": processed_videos, "timestamps": timestamps_list, "durations": durations_list}

         if do_pad:
             data["pixel_attention_mask"] = (
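Side note (not in the diff): when sampling is disabled, the timestamps and duration above are synthesized under an assumed 24 fps, e.g.:

    num_frames = 72
    timestamps = [(int((idx / 24) // 60), int((idx / 24) % 60)) for idx in range(num_frames)]
    duration = num_frames // 24
    # timestamps[0] == (0, 0), timestamps[48] == (0, 2), duration == 3 (seconds)
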
@@ -46,6 +46,7 @@ class VideoLlavaVideoProcessor(BaseVideoProcessor):
     do_rescale = True
     do_normalize = True
     do_convert_rgb = True
+    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
     valid_kwargs = VideoLlavaFastVideoProcessorInitKwargs
     model_input_names = ["pixel_values_videos"]

@@ -24,7 +24,7 @@ import typing
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, TypedDict, Union
+from typing import Any, Dict, Optional, TypedDict, Union

 import numpy as np
 import typing_extensions
@@ -33,9 +33,9 @@ from huggingface_hub.errors import EntryNotFoundError
 from .audio_utils import load_audio
 from .dynamic_module_utils import custom_object_save
 from .feature_extraction_utils import BatchFeature
-from .image_utils import ChannelDimension, ImageInput, is_valid_image, is_vision_available, load_image
+from .image_utils import ChannelDimension, is_valid_image, is_vision_available, load_image
 from .utils.chat_template_utils import render_jinja_template
-from .video_utils import VideoInput, load_video
+from .video_utils import VideoMetadata, load_video


 if is_vision_available():
@@ -64,6 +64,7 @@ from .utils import (
     list_repo_templates,
     logging,
 )
+from .utils.deprecation import deprecate_kwarg


 logger = logging.get_logger(__name__)
@@ -235,6 +236,14 @@ class VideosKwargs(TypedDict, total=False):
             Whether to pad the video to the `(max_height, max_width)` of the videos in the batch.
         do_center_crop (`bool`, *optional*):
             Whether to center crop the video.
+        do_sample_frames (`bool`, *optional*):
+            Whether to sample frames from the video before processing or to process the whole video.
+        video_metadata (`VideoMetadata`, *optional*):
+            Metadata of the video containing information about total duration, fps and total number of frames.
+        num_frames (`int`, *optional*):
+            Maximum number of frames to sample when `do_sample_frames=True`.
+        fps (`int`, *optional*):
+            Target frames to sample per second when `do_sample_frames=True`.
         crop_size (`Dict[str, int]`, *optional*):
             Desired output size when applying center-cropping.
         data_format (`ChannelDimension` or `str`, *optional*):
@@ -260,6 +269,10 @@ class VideosKwargs(TypedDict, total=False):
     data_format: Optional[ChannelDimension]
     input_data_format: Optional[Union[str, ChannelDimension]]
     device: Optional[str]
+    do_sample_frames: Optional[bool]
+    video_metadata: Optional[Union[VideoMetadata, dict]]
+    fps: Optional[int]
+    num_frames: Optional[int]

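An illustrative call (not from the commit) showing how the new `VideosKwargs` entries are meant to flow through a video processor; the checkpoint id is only an example:

    import torch
    from transformers import AutoVideoProcessor

    video_processor = AutoVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")  # example checkpoint
    videos = [torch.randint(0, 255, (32, 3, 224, 224), dtype=torch.uint8)]
    out = video_processor(videos, do_sample_frames=True, num_frames=8, return_tensors="pt")
    # passing fps instead of num_frames additionally requires video_metadata
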
 class AudioKwargs(TypedDict, total=False):
@@ -409,9 +422,6 @@ class ChatTemplateLoadKwargs(TypedDict, total=False):
             The backend to use when loading the video which will be used only when there are videos in the conversation.
             Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
             that supports all types of sources to load from.
-        video_fps (`int`, *optional*):
-            Number of frames to sample per second. Should be passed only when `num_frames=None`.
-            If not specified and `num_frames==None`, all frames are sampled.
         sample_indices_fn (`Callable`, *optional*):
             A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
             by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
@@ -424,9 +434,7 @@ class ChatTemplateLoadKwargs(TypedDict, total=False):
                 return np.linspace(start_idx, end_idx, num_frames, dtype=int)
     """

-    num_frames: Optional[int] = None
     video_load_backend: Optional[str] = "pyav"
-    video_fps: Optional[int] = None
     sampling_rate: Optional[int] = 16_000
     load_audio_from_video: Optional[bool] = False

@@ -1371,40 +1379,7 @@ class ProcessorMixin(PushToHubMixin):
             )
         return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}

-    def _process_messages_for_chat_template(
-        self,
-        conversation: List[List[Dict[str, str]]],
-        batch_images: List[ImageInput],
-        batch_videos: List[VideoInput],
-        batch_video_metadata: List[List[Dict[str, any]]],
-        **mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
-    ):
-        """
-        Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
-        video models might want to specify in the prompt the duration of video or which frame indices at which timestamps
-        were sampled. This information cannot be accessed before the video is loaded.
-
-        For most models it is a no-op, and must be overridden by model processors which require special processing.
-
-        Args:
-            conversation (`List[Dict, str, str]`):
-                The conversation to process. Always comes in batched format.
-            batch_images (`List[List[ImageInput]]`):
-                Batch of images that were loaded from url/path defined in the conversation. The images
-                are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL` images
-                per batch.
-            batch_videos (`List[List[ImageInput]]`):
-                Batch of videos that were loaded from url/path defined in the conversation. The videos
-                are ordered in the same way as in the conversation. Comes in nested list format, one list of 4D video arrays
-                per batch.
-            batch_video_metadata (`List[List[Dict[[str, any]]]]`):
-                Batch of metadata returned from loading videos. That includes video fps, duration and total number of frames in the original video.
-                Metadata are ordered in the same way as `batch_videos`. Comes in nested list format, one list of 4D video arrays
-                per batch.
-        """
-        return conversation
-
+    @deprecate_kwarg("video_fps", version="4.58", new_name="fps")
     def apply_chat_template(
         self,
         conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
@@ -1423,7 +1398,7 @@ class ProcessorMixin(PushToHubMixin):
                 {
                     "role": "user",
                     "content": [
-                        {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+                        {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                         {"type": "text", "text": "Please describe this image in detail."},
                     ],
                 },
@@ -1436,7 +1411,6 @@ class ProcessorMixin(PushToHubMixin):
                 The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
                 chat template is used.
         """
-
         if chat_template is None:
             if isinstance(self.chat_template, dict) and "default" in self.chat_template:
                 chat_template = self.chat_template["default"]
@@ -1545,16 +1519,12 @@ class ProcessorMixin(PushToHubMixin):
                         metadata = None
                         logger.warning(
                             "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
-                            "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
+                            "If your model requires metadata during processing, please load the whole video and let the processor sample frames instead."
                         )
                     else:
-                        # TODO: raushan, should be `self.video_processor.load_video_for_model` when API is added
-                        video, metadata = self._load_video_for_model(
+                        video, metadata = load_video(
                             fname,
-                            num_frames=mm_load_kwargs.get("num_frames", None),
-                            fps=mm_load_kwargs.get("video_fps", None),
                             backend=mm_load_kwargs["video_load_backend"],
-                            **kwargs,
                         )
                     videos.append(video)
                     video_metadata.append(metadata)
@@ -1567,15 +1537,6 @@ class ProcessorMixin(PushToHubMixin):
                 batch_videos.append(videos)
                 batch_video_metadata.append(video_metadata)

-        # Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
-        conversations = self._process_messages_for_chat_template(
-            conversations,
-            batch_images=batch_images,
-            batch_videos=batch_videos,
-            batch_video_metadata=batch_video_metadata,
-            **processed_kwargs["mm_load_kwargs"],
-        )
-
         prompt, generation_indices = render_jinja_template(
             conversations=conversations,
             chat_template=chat_template,
@@ -1597,11 +1558,17 @@ class ProcessorMixin(PushToHubMixin):
             if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
                 kwargs["add_special_tokens"] = False

+        # Always sample frames by default unless explicitly set to `False` by users. If users do not pass `num_frames`/`video_fps`,
+        # sampling should not be done for BC.
+        if "do_sample_frames" not in kwargs and ("fps" in kwargs or "num_frames" in kwargs):
+            kwargs["do_sample_frames"] = True
+
         out = self(
             text=prompt,
             images=batch_images if batch_images else None,
             videos=batch_videos if batch_videos else None,
             audio=batch_audios if batch_audios else None,
+            video_metadata=batch_video_metadata,
             **kwargs,
         )
         if return_dict:
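A hedged usage sketch (not part of the diff) of the behaviour added above: passing `num_frames` (or `fps`) to `apply_chat_template` now implies `do_sample_frames=True` unless the caller sets it explicitly. The processor variable and video path are placeholders:

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video", "path": "/path/to/clip.mp4"},   # placeholder path
                {"type": "text", "text": "What is happening in this video?"},
            ],
        }
    ]
    # `processor` is any video-capable processor loaded via AutoProcessor.from_pretrained(...)
    out = processor.apply_chat_template(
        messages, tokenize=True, return_dict=True, return_tensors="pt", num_frames=8
    )
    # num_frames was given and do_sample_frames was not, so sampling is switched on automatically
    # and the loaded metadata is forwarded to the video processor as `video_metadata`
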
@@ -1626,38 +1593,6 @@ class ProcessorMixin(PushToHubMixin):
                 return out["input_ids"]
         return prompt

-    # TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
-    # Keep private so we can simply remove when needed
-    def _load_video_for_model(
-        self,
-        video: Union[str, "VideoInput"],
-        num_frames: Optional[int] = None,
-        fps: Optional[int] = None,
-        backend: str = "opencv",
-        **kwargs,
-    ) -> np.array:
-        """
-        Loads `video` to a numpy array.
-
-        Args:
-            video (`str` or `VideoInput`):
-                The video to convert to the numpy array format. Can be a link to video or local path.
-            num_frames (`int`, *optional*):
-                Number of frames to sample uniformly. If not passed, the whole video is loaded.
-            fps (`int`, *optional*):
-                Number of frames to sample per second. Should be passed only when `num_frames=None`.
-                If not specified and `num_frames==None`, all frames are sampled.
-            backend (`str`, *optional*, defaults to `"opencv"`):
-                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
-
-        Returns:
-            Tuple[`np.array`, Dict]: A tuple containing:
-                - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
-                - Metadata dictionary.
-        """
-        video, metadata = load_video(video, num_frames, fps=fps, backend=backend)
-        return video, metadata
-
     def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
         """
         Post-process the output of a vlm to decode the text.
@@ -51,6 +51,7 @@ from .utils import (
 from .utils.import_utils import requires
 from .video_utils import (
     VideoInput,
+    VideoMetadata,
     group_videos_by_shape,
     load_video,
     make_batched_videos,
@@ -118,6 +119,14 @@ BASE_VIDEO_PROCESSOR_DOCSTRING = r"""
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `self.image_std`):
             Whether to convert the video to RGB.
+        video_metadata (`VideoMetadata`, *optional*):
+            Metadata of the video containing information about total duration, fps and total number of frames.
+        do_sample_frames (`bool`, *optional*, defaults to `self.do_sample_frames`):
+            Whether to sample frames from the video before processing or to process the whole video.
+        num_frames (`int`, *optional*, defaults to `self.num_frames`):
+            Maximum number of frames to sample when `do_sample_frames=True`.
+        fps (`int`, *optional*, defaults to `self.fps`):
+            Target frames to sample per second when `do_sample_frames=True`.
         return_tensors (`str` or `TensorType`, *optional*):
             Returns stacked tensors if set to `pt, otherwise returns a list of tensors.
         data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
@@ -157,6 +166,10 @@ class BaseVideoProcessor(BaseImageProcessorFast):
     rescale_factor = 1 / 255
     do_normalize = None
     do_convert_rgb = None
+    do_sample_frames = None
+    fps = None
+    num_frames = None
+    video_metadata = None
     valid_kwargs = VideosKwargs
     model_input_names = ["pixel_values_videos"]

@@ -219,9 +232,67 @@ class BaseVideoProcessor(BaseImageProcessorFast):
         video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
         return video

+    def sample_frames(
+        self,
+        video: "torch.Tensor",
+        metadata: Optional[Union[VideoMetadata, dict]] = None,
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+    ):
+        """
+        Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+        and `fps` are mutually exclusive.
+
+        Args:
+            video (`torch.Tensor`):
+                Video that needs to be sampled.
+            metadata (`VideoMetadata`, *optional*):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+
+        Returns:
+            torch.Tensor:
+                Sampled video frames.
+        """
+        if fps is not None and num_frames is not None:
+            raise ValueError(
+                "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
+            )
+
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        fps = fps if fps is not None else self.fps
+        total_num_frames = video.shape[0]
+
+        # If num_frames is not given but fps is, calculate num_frames from fps
+        if num_frames is None and fps is not None:
+            if metadata is None:
+                raise ValueError(
+                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+                    "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
+                )
+            num_frames = int(total_num_frames / metadata["fps"] * fps)
+
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+            )
+
+        if num_frames is not None:
+            indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
+        else:
+            indices = torch.arange(0, total_num_frames).int()
+
+        video = video[indices].contiguous()
+        return video

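Worked example (not from the repo) of the index arithmetic in the default `sample_frames` above, for a 120-frame clip stored at 30 fps when asking for 2 frames per second:

    import torch

    total_num_frames, metadata_fps, target_fps = 120, 30, 2
    num_frames = int(total_num_frames / metadata_fps * target_fps)                  # -> 8
    indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
    # tensor([  0,  15,  30,  45,  60,  75,  90, 105], dtype=torch.int32)
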
     def _prepare_input_videos(
         self,
         videos: VideoInput,
+        video_metadata: VideoMetadata = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
     ) -> List["torch.Tensor"]:
@@ -229,6 +300,11 @@ class BaseVideoProcessor(BaseImageProcessorFast):
         Prepare the input videos for processing.
         """
         videos = make_batched_videos(videos)
+        if video_metadata is not None:
+            batch_metadata = [metadata for batch_list in video_metadata for metadata in batch_list]
+        else:
+            batch_metadata = [None] * len(videos)
+
         processed_videos = []
         for video in videos:
             # `make_batched_videos` always returns a 4D array per video
@@ -242,7 +318,7 @@ class BaseVideoProcessor(BaseImageProcessorFast):
                 video = video.to(device)

             processed_videos.append(video)
-        return processed_videos
+        return processed_videos, batch_metadata

     @add_start_docstrings(BASE_VIDEO_PROCESSOR_DOCSTRING)
     def preprocess(
@@ -261,7 +337,10 @@ class BaseVideoProcessor(BaseImageProcessorFast):

         input_data_format = kwargs.pop("input_data_format")
         device = kwargs.pop("device")
-        videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)
+        video_metadata = kwargs.pop("video_metadata")
+        videos, video_metadata = self._prepare_input_videos(
+            videos=videos, video_metadata=video_metadata, input_data_format=input_data_format, device=device
+        )

         kwargs = self._further_process_kwargs(**kwargs)
         self._validate_preprocess_kwargs(**kwargs)
@@ -276,11 +355,12 @@ class BaseVideoProcessor(BaseImageProcessorFast):
         kwargs.pop("default_to_square")
         kwargs.pop("data_format")

-        return self._preprocess(videos=videos, **kwargs)
+        return self._preprocess(videos=videos, video_metadata=video_metadata, **kwargs)

     def _preprocess(
         self,
         videos: List["torch.Tensor"],
+        video_metadata: Union[List[VideoMetadata], List[dict]],
         do_convert_rgb: bool,
         do_resize: bool,
         size: SizeDict,
@@ -294,8 +374,18 @@ class BaseVideoProcessor(BaseImageProcessorFast):
         do_normalize: bool,
         image_mean: Optional[Union[float, List[float]]],
         image_std: Optional[Union[float, List[float]]],
+        do_sample_frames: Optional[bool] = None,
+        fps: Optional[int] = None,
+        num_frames: Optional[int] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
     ) -> BatchFeature:
+        if do_sample_frames:
+            # Sample video frames
+            videos = [
+                self.sample_frames(video, metadata=metadata, num_frames=num_frames, fps=fps)
+                for video, metadata in zip(videos, video_metadata)
+            ]
+
         # Group videos by size for batched resizing
         grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
         resized_videos_grouped = {}
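Sketch (not from the commit) of how metadata reaches `_preprocess`: `video_metadata` is expected nested per conversation, mirroring the videos, and is flattened by `_prepare_input_videos`; the values below are invented:

    # video_processor / videos as in the earlier AutoVideoProcessor sketch
    metadata = [[{"total_num_frames": 60, "fps": 30, "duration": 2.0, "video_backend": "decord"}]]
    out = video_processor(videos, video_metadata=metadata, do_sample_frames=True, fps=5, return_tensors="pt")
    # with fps=5 on a 2 s clip, the default sampler keeps int(60 / 30 * 5) == 10 frames
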
@@ -74,6 +74,9 @@ class VideoMetadata:
     duration: float
     video_backend: str

+    def __getitem__(self, item):
+        return getattr(self, item)
+

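A minimal stand-in (not the real class) showing why `__getitem__` is added: processors can now treat `VideoMetadata` objects and plain metadata dicts interchangeably:

    from dataclasses import dataclass

    @dataclass
    class VideoMetadata:  # simplified copy for illustration
        total_num_frames: int
        fps: float
        duration: float
        video_backend: str

        def __getitem__(self, item):
            return getattr(self, item)

    meta = VideoMetadata(total_num_frames=240, fps=24.0, duration=10.0, video_backend="pyav")
    assert meta["fps"] == meta.fps == 24.0
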
 def is_valid_video_frame(frame):
     return isinstance(frame, PIL.Image.Image) or (
@@ -163,7 +166,7 @@ def make_batched_videos(videos) -> List[Union["np.ndarray", "torch.Tensor"]]:
         videos = [np.array(videos)[None, ...]]
     # nested batch so we need to unflatten
     elif isinstance(videos[0], (list, tuple)) and is_valid_video(videos[0][0]):
-        return [video for sublist in videos for video in sublist]
+        videos = [video for sublist in videos for video in sublist]
     return convert_pil_frames_to_video(videos)

@@ -17,7 +17,6 @@ import shutil
 import tempfile
 import unittest

-from huggingface_hub import hf_hub_download
 from parameterized import parameterized

 from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
@@ -180,77 +179,6 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
             )
             images_patches_index += inputs["pixel_values"].shape[0]

-    # Override video chat_template tests as InternVLProcessor returns flattened video features
-    @require_av
-    @require_torch
-    def test_apply_chat_template_video_special_processing(self):
-        """
-        Tests that models can use their own preprocessing to preprocess conversations.
-        """
-        processor = self.get_processor()
-        if processor.chat_template is None:
-            self.skipTest("Processor has no chat template")
-
-        signature = inspect.signature(processor.__call__)
-        if "videos" not in {*signature.parameters.keys()} or (
-            signature.parameters.get("videos") is not None
-            and signature.parameters["videos"].annotation == inspect._empty
-        ):
-            self.skipTest("Processor doesn't accept videos at input")
-
-        video_file_path = hf_hub_download(
-            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
-        )
-        messages = [
-            [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "path": video_file_path},
-                        {"type": "text", "text": "What is shown in this video?"},
-                    ],
-                },
-            ]
-        ]
-
-        def _process_messages_for_chat_template(
-            conversation,
-            batch_images,
-            batch_videos,
-            batch_video_metadata,
-            **chat_template_kwargs,
-        ):
-            # Let us just always return a dummy prompt
-            new_msg = [
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "video"},  # no need to use path, video is loaded already by this moment
-                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
-                        ],
-                    },
-                ]
-            ]
-            return new_msg
-
-        processor._process_messages_for_chat_template = _process_messages_for_chat_template
-        out_dict_with_video = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-            num_frames=8,
-        )
-        self.assertTrue(self.videos_input_name in out_dict_with_video)
-
-        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
-        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
-        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
-        # Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
-        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
-
     @require_torch
     @require_av
     def test_apply_chat_template_video_frame_sampling(self):
@@ -393,13 +321,13 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
-            num_frames=4,  # by default no more than 4 frames, otherwise too slow
+            num_frames=2,  # by default no more than 2 frames, otherwise too slow
         )
         self.assertTrue(self.videos_input_name in out_dict)
         self.assertEqual(len(out_dict["input_ids"]), batch_size)
         self.assertEqual(len(out_dict["attention_mask"]), batch_size)

-        video_len = 4 if batch_size == 1 else 3  # InternVL patches out and removes frames after processing
+        video_len = 2 if batch_size == 1 else 3  # InternVL patches out and removes frames after processing
         self.assertEqual(len(out_dict[self.videos_input_name]), video_len)
         for k in out_dict:
             self.assertIsInstance(out_dict[k], torch.Tensor)
@@ -407,14 +407,14 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
             tokenize=True,
             return_dict=True,
             return_tensors=return_tensors,
-            num_frames=4,  # by default no more than 4 frames, otherwise too slow
+            num_frames=2,  # by default no more than 2 frames, otherwise too slow
         )
         input_name = getattr(self, input_name)
         self.assertTrue(input_name in out_dict)
         self.assertEqual(len(out_dict["input_ids"]), batch_size)
         self.assertEqual(len(out_dict["attention_mask"]), batch_size)

-        video_len = 5760 if batch_size == 1 else 5808  # qwen pixels don't scale with bs same way as other models
+        video_len = 2880 if batch_size == 1 else 5808  # qwen pixels don't scale with bs same way as other models
         mm_len = batch_size * 1564 if modality == "image" else video_len
         self.assertEqual(len(out_dict[input_name]), mm_len)

@@ -525,72 +525,6 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 2904)

-    @require_av
-    def test_apply_chat_template_video_special_processing(self):
-        """
-        Tests that models can use their own preprocessing to preprocess conversations.
-        """
-        processor = self.get_processor()
-        if processor.chat_template is None:
-            self.skipTest("Processor has no chat template")
-
-        signature = inspect.signature(processor.__call__)
-        if "videos" not in {*signature.parameters.keys()} or (
-            signature.parameters.get("videos") is not None
-            and signature.parameters["videos"].annotation == inspect._empty
-        ):
-            self.skipTest("Processor doesn't accept videos at input")
-
-        video_file_path = hf_hub_download(
-            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
-        )
-        messages = [
-            [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "path": video_file_path},
-                        {"type": "text", "text": "What is shown in this video?"},
-                    ],
-                },
-            ]
-        ]
-
-        def _process_messages_for_chat_template(
-            conversation,
-            batch_images,
-            batch_videos,
-            batch_video_metadata,
-            **chat_template_kwargs,
-        ):
-            # Let us just always return a dummy prompt
-            new_msg = [
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "video"},  # no need to use path, video is loaded already by this moment
-                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
-                        ],
-                    },
-                ]
-            ]
-            return new_msg
-
-        processor._process_messages_for_chat_template = _process_messages_for_chat_template
-        out_dict_with_video = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-        )
-        self.assertTrue(self.videos_input_name in out_dict_with_video)
-
-        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
-        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
-        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
-        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 145912)
-
     @require_librosa
     @require_av
     @unittest.skip(
@@ -19,7 +19,6 @@ import unittest

 import numpy as np
 import pytest
-from huggingface_hub import hf_hub_download

 from transformers import AutoProcessor, Qwen2Tokenizer
 from transformers.testing_utils import require_av, require_torch, require_vision
@@ -219,14 +218,14 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
             tokenize=True,
             return_dict=True,
             return_tensors=return_tensors,
-            num_frames=4,  # by default no more than 4 frames, otherwise too slow
+            num_frames=2,  # by default no more than 2 frames, otherwise too slow
         )
         input_name = getattr(self, input_name)
         self.assertTrue(input_name in out_dict)
         self.assertEqual(len(out_dict["input_ids"]), batch_size)
         self.assertEqual(len(out_dict["attention_mask"]), batch_size)

-        video_len = 360 if batch_size == 1 else 320  # qwen pixels don't scale with bs same way as other models
+        video_len = 180 if batch_size == 1 else 320  # qwen pixels don't scale with bs same way as other models
         mm_len = batch_size * 192 if modality == "image" else video_len
         self.assertEqual(len(out_dict[input_name]), mm_len)

@@ -346,70 +345,3 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertEqual(inputs[self.images_input_name].shape[0], 612)
         inputs = processor(text=input_str, images=image_input, return_tensors="pt")
         self.assertEqual(inputs[self.images_input_name].shape[0], 100)
-
-    @require_av
-    def test_apply_chat_template_video_special_processing(self):
-        """
-        Tests that models can use their own preprocessing to preprocess conversations.
-        """
-        processor = self.get_processor()
-        if processor.chat_template is None:
-            self.skipTest("Processor has no chat template")
-
-        signature = inspect.signature(processor.__call__)
-        if "videos" not in {*signature.parameters.keys()} or (
-            signature.parameters.get("videos") is not None
-            and signature.parameters["videos"].annotation == inspect._empty
-        ):
-            self.skipTest("Processor doesn't accept videos at input")
-
-        video_file_path = hf_hub_download(
-            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
-        )
-        messages = [
-            [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "path": video_file_path},
-                        {"type": "text", "text": "What is shown in this video?"},
-                    ],
-                },
-            ]
-        ]
-
-        def _process_messages_for_chat_template(
-            conversation,
-            batch_images,
-            batch_videos,
-            batch_video_metadata,
-            **chat_template_kwargs,
-        ):
-            # Let us just always return a dummy prompt
-            new_msg = [
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "video"},  # no need to use path, video is loaded already by this moment
-                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
-                        ],
-                    },
-                ]
-            ]
-            return new_msg
-
-        processor._process_messages_for_chat_template = _process_messages_for_chat_template
-        out_dict_with_video = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        self.assertTrue(self.videos_input_name in out_dict_with_video)
-
-        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
-        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
-        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
-        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 21960)
@@ -19,7 +19,6 @@ import unittest

 import numpy as np
 import pytest
-from huggingface_hub import hf_hub_download

 from transformers import AutoProcessor, Qwen2Tokenizer
 from transformers.testing_utils import require_av, require_torch, require_vision
@@ -220,14 +219,14 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
             tokenize=True,
             return_dict=True,
             return_tensors=return_tensors,
-            num_frames=4,  # by default no more than 4 frames, otherwise too slow
+            num_frames=2,  # by default no more than 2 frames, otherwise too slow
         )
         input_name = getattr(self, input_name)
         self.assertTrue(input_name in out_dict)
         self.assertEqual(len(out_dict["input_ids"]), batch_size)
         self.assertEqual(len(out_dict["attention_mask"]), batch_size)

-        video_len = 360 if batch_size == 1 else 320  # qwen pixels don't scale with bs same way as other models
+        video_len = 180 if batch_size == 1 else 320  # qwen pixels don't scale with bs same way as other models
         mm_len = batch_size * 192 if modality == "image" else video_len
         self.assertEqual(len(out_dict[input_name]), mm_len)

@@ -337,73 +336,6 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 160)

-    @require_av
-    def test_apply_chat_template_video_special_processing(self):
-        """
-        Tests that models can use their own preprocessing to preprocess conversations.
-        """
-        processor = self.get_processor()
-        if processor.chat_template is None:
-            self.skipTest("Processor has no chat template")
-
-        signature = inspect.signature(processor.__call__)
-        if "videos" not in {*signature.parameters.keys()} or (
-            signature.parameters.get("videos") is not None
-            and signature.parameters["videos"].annotation == inspect._empty
-        ):
-            self.skipTest("Processor doesn't accept videos at input")
-
-        video_file_path = hf_hub_download(
-            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
-        )
-        messages = [
-            [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "path": video_file_path},
-                        {"type": "text", "text": "What is shown in this video?"},
-                    ],
-                },
-            ]
-        ]
-
-        def _process_messages_for_chat_template(
-            conversation,
-            batch_images,
-            batch_videos,
-            batch_video_metadata,
-            **chat_template_kwargs,
-        ):
-            # Let us just always return a dummy prompt
-            new_msg = [
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "video"},  # no need to use path, video is loaded already by this moment
-                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
-                        ],
-                    },
-                ]
-            ]
-            return new_msg
-
-        processor._process_messages_for_chat_template = _process_messages_for_chat_template
-        out_dict_with_video = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        self.assertTrue(self.videos_input_name in out_dict_with_video)
-
-        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
-        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
-        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
-        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 21960)
-
     def test_kwargs_overrides_custom_image_processor_kwargs(self):
         processor = self.get_processor()
         self.skip_processor_without_typed_kwargs(processor)
@@ -99,8 +99,9 @@ class Qwen2VLVideoProcessingTester:
         }

     @require_vision
-    def expected_output_video_shape(self, videos):
-        grid_t = self.num_frames // self.temporal_patch_size
+    def expected_output_video_shape(self, videos, num_frames=None):
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        grid_t = num_frames // self.temporal_patch_size
         hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
         seq_len = 0
         for video in videos:
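A quick illustration of the arithmetic the helper above encodes once `num_frames` can be overridden. The concrete values (temporal_patch_size=2, patch_size=14, num_channels=3, a 112x112 video) are assumptions for the example, not the tester's real configuration, and the per-video accumulation mirrors what the loop in the tester presumably does:

# Illustrative values only; the real tester supplies its own configuration.
num_frames, temporal_patch_size, patch_size, num_channels = 8, 2, 14, 3
height, width = 112, 112  # one video's spatial size after resizing

grid_t = num_frames // temporal_patch_size                   # 4 temporal patches
grid_h, grid_w = height // patch_size, width // patch_size   # 8 x 8 spatial patches
hidden_dim = num_channels * temporal_patch_size * patch_size * patch_size  # 1176
seq_len = grid_t * grid_h * grid_w                           # 256 patch tokens contributed by this video
print(seq_len, hidden_dim)  # the expected flattened output is roughly (sum of seq_len over videos, hidden_dim)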
@@ -289,3 +290,70 @@ class Qwen2VLVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
         )[self.input_name]
         expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
         self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
+
+    def test_call_sample_frames(self):
+        for video_processing_class in self.video_processor_list:
+            video_processing = video_processing_class(**self.video_processor_dict)
+
+            prev_num_frames = self.video_processor_tester.num_frames
+            self.video_processor_tester.num_frames = 8
+            video_inputs = self.video_processor_tester.prepare_video_inputs(
+                equal_resolution=False,
+                return_tensors="torch",
+            )
+
+            # Force set sampling to False. No sampling is expected even when `num_frames` exists
+            video_processing.do_sample_frames = False
+
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
+            encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
+            expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
+            expected_output_video_shape_batched = self.video_processor_tester.expected_output_video_shape(video_inputs)
+            self.assertListEqual(list(encoded_videos.shape), expected_output_video_shape)
+            self.assertListEqual(list(encoded_videos_batched.shape), expected_output_video_shape_batched)
+
+            # Set sampling to True. Video frames should be sampled with `num_frames` in the output
+            video_processing.do_sample_frames = True
+
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=4)[self.input_name]
+            encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=4)[self.input_name]
+            expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(
+                [video_inputs[0]], num_frames=4
+            )
+            expected_output_video_shape_batched = self.video_processor_tester.expected_output_video_shape(
+                video_inputs, num_frames=4
+            )
+            self.assertListEqual(list(encoded_videos.shape), expected_output_video_shape)
+            self.assertListEqual(list(encoded_videos_batched.shape), expected_output_video_shape_batched)
+
+            # Sample with `fps` requires metadata to infer number of frames from total duration
+            with self.assertRaises(ValueError):
+                encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3)[self.input_name]
+                encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", fps=3)[self.input_name]
+
+            metadata = [[{"duration": 2.0, "total_num_frames": 8, "fps": 4}]]
+            batched_metadata = metadata * len(video_inputs)
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3, video_metadata=metadata)[
+                self.input_name
+            ]
+            encoded_videos_batched = video_processing(
+                video_inputs, return_tensors="pt", fps=3, video_metadata=batched_metadata
+            )[self.input_name]
+            expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(
+                [video_inputs[0]], num_frames=6
+            )
+            expected_output_video_shape_batched = self.video_processor_tester.expected_output_video_shape(
+                video_inputs, num_frames=6
+            )
+            self.assertListEqual(list(encoded_videos.shape), expected_output_video_shape)
+            self.assertListEqual(list(encoded_videos_batched.shape), expected_output_video_shape_batched)
+
+            # We should raise error when asked to sample more frames than there are in input video
+            with self.assertRaises(ValueError):
+                encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=10)[self.input_name]
+                encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=10)[
+                    self.input_name
+                ]
+
+            # Assign back the actual num frames in tester
+            self.video_processor_tester.num_frames = prev_num_frames
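The `fps` path in the test above only works once per-video metadata is supplied. A minimal sketch of the arithmetic the assertions rely on; `infer_num_frames` is a hypothetical helper written for this illustration, not the library API:

def infer_num_frames(metadata: dict, fps: int) -> int:
    # With the metadata used in the test ({"duration": 2.0, "total_num_frames": 8, "fps": 4}),
    # requesting fps=3 means 2.0 s * 3 frames/s = 6 frames, matching the expected shapes above.
    requested = int(metadata["duration"] * fps)
    if requested > metadata["total_num_frames"]:
        raise ValueError("Cannot sample more frames than the video contains")
    return requested

assert infer_num_frames({"duration": 2.0, "total_num_frames": 8, "fps": 4}, fps=3) == 6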
@@ -16,6 +16,7 @@ import shutil
 import tempfile
 import unittest
 from io import BytesIO
+from typing import Optional

 import numpy as np
 import requests
@@ -63,7 +64,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         )
         cls.bos_token = processor.tokenizer.bos_token
         cls.image_token = processor.image_token
-        cls.video_token = processor.image_token * 8  # SmolVLM uses image token and repeats it `num_frames` times
+        cls.video_token = processor.video_token
         cls.fake_image_token = processor.fake_image_token
         cls.global_img_token = processor.global_image_token

@@ -93,6 +94,13 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
            "chat_template": "<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
         }

+    def prepare_video_inputs(self, batch_size: Optional[int] = None):
+        """This function prepares a list of numpy videos."""
+        video_input = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8
+        if batch_size is None:
+            return [[video_input]]
+        return [[video_input]] * batch_size
+
     def get_split_image_expected_tokens(self, processor, image_rows, image_cols):
         text_split_images = []
         for n_h in range(image_rows):
@@ -347,7 +355,6 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
                     {"type": "text", "text": "What do these images show?"},
                     {"type": "image"},
                     {"type": "image"},
-                    "What do these images show?",
                 ],
             },
             {
@@ -373,11 +380,8 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         )
         self.assertEqual(rendered, expected_rendered)

-    @unittest.skip(reason="SmolVLM replaced `type=video` with `type=image` in chat templates")
-    def test_apply_chat_template_video_special_processing(self):
-        pass
-
     @require_av
+    @require_torch
     def test_apply_chat_template_video_frame_sampling(self):
         # overridden because SmolVLM has special preprocessing for videos
         processor = self.get_processor()
@@ -406,7 +410,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
            tokenize=True,
            return_dict=True,
            num_frames=num_frames,
-            return_tensors="np",
+            return_tensors="pt",
         )
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@@ -421,7 +425,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
            tokenize=True,
            return_dict=True,
            video_fps=video_fps,
-            return_tensors="np",
+            return_tensors="pt",
         )
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
@@ -482,11 +486,11 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
            do_rescale=True,
            rescale_factor=-1,
            padding="max_length",
-            max_length=76,
+            max_length=172,
         )

         self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
-        self.assertEqual(len(inputs["input_ids"][0]), 76)
+        self.assertEqual(len(inputs["input_ids"][0]), 172)

     @require_torch
     @require_vision
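The new `prepare_video_inputs` above builds a doubly nested structure: a batch of samples, each sample holding a list of videos, each video a list of frames. A small, purely illustrative sketch of the resulting shapes, reusing the same frame size the test uses:

import numpy as np

# One video = a list of 8 frames, each frame a (channels, height, width) array.
video = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] * 8

batch = [[video]]             # one sample containing one video
batch_of_two = [[video]] * 2  # two samples, one video each

print(len(batch), len(batch[0]), len(batch[0][0]), batch[0][0][0].shape)
# 1 1 8 (3, 30, 400)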
@@ -15,22 +15,16 @@

 import unittest

-import numpy as np
-
 from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available

 from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs


-if is_torch_available():
-    import torch
-
 if is_vision_available():
     if is_torchvision_available():
         from transformers import SmolVLMVideoProcessor
-        from transformers.models.smolvlm.video_processing_smolvlm import get_resize_output_image_size


 class SmolVLMVideoProcessingTester:
@@ -58,6 +52,7 @@ class SmolVLMVideoProcessingTester:
         self.max_resolution = max_resolution
         self.do_resize = do_resize
         self.size = size
+        self.max_image_size = size
         self.do_normalize = do_normalize
         self.image_mean = image_mean
         self.image_std = image_std
@@ -71,17 +66,16 @@ class SmolVLMVideoProcessingTester:
             "image_mean": self.image_mean,
             "image_std": self.image_std,
             "do_convert_rgb": self.do_convert_rgb,
+            "max_image_size": self.max_image_size,
         }

     def expected_output_video_shape(self, videos):
-        max_height, max_width = 0, 0
-        if not isinstance(videos[0], torch.Tensor):
-            videos = [torch.tensor(np.array(video)).permute(0, -1, -3, -2) for video in videos]
-        for video in videos:
-            height, width = get_resize_output_image_size(video, self.size["longest_edge"])
-            max_height = max(height, max_height)
-            max_width = max(width, max_width)
-        return [self.num_frames, self.num_channels, max_height, max_width]
+        return [
+            self.num_frames,
+            self.num_channels,
+            self.max_image_size["longest_edge"],
+            self.max_image_size["longest_edge"],
+        ]

     def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
         videos = prepare_video_inputs(
@@ -116,3 +110,58 @@ class SmolVLMVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):

         video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42)
         self.assertEqual(video_processor.size, {"height": 42, "width": 42})
+
+    # overwrite, SmolVLM requires to have metadata no matter how we sample
+    def test_call_sample_frames(self):
+        for video_processing_class in self.video_processor_list:
+            video_processing = video_processing_class(**self.video_processor_dict)
+
+            prev_num_frames = self.video_processor_tester.num_frames
+            self.video_processor_tester.num_frames = 8
+            video_inputs = self.video_processor_tester.prepare_video_inputs(
+                equal_resolution=False,
+                return_tensors="torch",
+            )
+
+            # Force set sampling to False. No sampling is expected even when `num_frames` exists
+            video_processing.do_sample_frames = False
+
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
+            encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
+            self.assertEqual(encoded_videos.shape[1], 8)
+            self.assertEqual(encoded_videos_batched.shape[1], 8)
+
+            # Set sampling to True. Video frames should be sampled with `num_frames` in the output
+            video_processing.do_sample_frames = True
+            metadata = [[{"duration": 2.0, "total_num_frames": 8, "fps": 4}]]
+            batched_metadata = metadata * len(video_inputs)
+
+            # Sample with `fps` requires metadata to infer number of frames from total duration
+            with self.assertRaises(ValueError):
+                encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=6, fps=3)[
+                    self.input_name
+                ]
+                encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=6, fps=3)[
+                    self.input_name
+                ]
+
+            encoded_videos = video_processing(
+                video_inputs[0], return_tensors="pt", num_frames=6, fps=3, video_metadata=metadata
+            )[self.input_name]
+            encoded_videos_batched = video_processing(
+                video_inputs, return_tensors="pt", num_frames=6, fps=3, video_metadata=batched_metadata
+            )[self.input_name]
+            self.assertEqual(encoded_videos.shape[1], 6)
+            self.assertEqual(encoded_videos_batched.shape[1], 6)
+
+            # We should raise error when asked to sample more frames than there are in input video
+            with self.assertRaises(ValueError):
+                encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=10, num_frames=20)[
+                    self.input_name
+                ]
+                encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", fps=10, num_frames=20)[
+                    self.input_name
+                ]
+
+            # Assign back the actual num frames in tester
+            self.video_processor_tester.num_frames = prev_num_frames
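With `max_image_size` wired through the tester, the expected output no longer depends on per-video resize arithmetic: every processed frame ends up as a `longest_edge` square, so the old torch/numpy helpers could be dropped. A tiny illustration of the shape this now pins down; the value 384 is an assumption for the example, not the suite's actual setting:

# Placeholder numbers for illustration only.
num_frames, num_channels, longest_edge = 8, 3, 384

expected_shape = [num_frames, num_channels, longest_edge, longest_edge]
print(expected_shape)  # e.g. [8, 3, 384, 384], regardless of the input resolution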
@@ -507,7 +507,7 @@ class ProcessorTesterMixin:
         if "video_processor" not in self.processor_class.attributes:
             self.skipTest(f"video_processor attribute not present in {self.processor_class}")
         processor_components = self.prepare_components()
-        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
+        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=167, padding="max_length")
         processor_kwargs = self.prepare_processor_dict()

         processor = self.processor_class(**processor_components, **processor_kwargs)
@@ -515,7 +515,7 @@ class ProcessorTesterMixin:
         input_str = self.prepare_text_inputs(modality="video")
         video_input = self.prepare_video_inputs()
         inputs = processor(text=input_str, videos=video_input, return_tensors="pt")
-        self.assertEqual(inputs[self.text_input_name].shape[-1], 117)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 167)

     def test_video_processor_defaults_preserved_by_video_kwargs(self):
         """
@@ -529,7 +529,7 @@ class ProcessorTesterMixin:
         processor_components["video_processor"] = self.get_component(
             "video_processor", do_rescale=True, rescale_factor=-1
         )
-        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
+        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=167, padding="max_length")
         processor_kwargs = self.prepare_processor_dict()

         processor = self.processor_class(**processor_components, **processor_kwargs)
@@ -553,9 +553,9 @@ class ProcessorTesterMixin:
         input_str = self.prepare_text_inputs(modality="video")
         video_input = self.prepare_video_inputs()
         inputs = processor(
-            text=input_str, videos=video_input, return_tensors="pt", max_length=112, padding="max_length"
+            text=input_str, videos=video_input, return_tensors="pt", max_length=162, padding="max_length"
         )
-        self.assertEqual(inputs[self.text_input_name].shape[-1], 112)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 162)

     def test_kwargs_overrides_default_video_processor_kwargs(self):
         if "video_processor" not in self.processor_class.attributes:
@@ -564,7 +564,7 @@ class ProcessorTesterMixin:
         processor_components["video_processor"] = self.get_component(
             "video_processor", do_rescale=True, rescale_factor=1
         )
-        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
+        processor_components["tokenizer"] = self.get_component("tokenizer", max_length=167, padding="max_length")
         processor_kwargs = self.prepare_processor_dict()

         processor = self.processor_class(**processor_components, **processor_kwargs)
@@ -593,11 +593,11 @@ class ProcessorTesterMixin:
            do_rescale=True,
            rescale_factor=-1,
            padding="max_length",
-            max_length=76,
+            max_length=176,
         )

         self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
-        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 176)

     def test_unstructured_kwargs_batched_video(self):
         if "video_processor" not in self.processor_class.attributes:
@@ -616,13 +616,13 @@ class ProcessorTesterMixin:
            do_rescale=True,
            rescale_factor=-1,
            padding="longest",
-            max_length=76,
+            max_length=176,
         )

         self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
         self.assertTrue(
             len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1])
-            and len(inputs[self.text_input_name][1]) < 76
+            and len(inputs[self.text_input_name][1]) < 176
         )

     def test_doubly_passed_kwargs_video(self):
@@ -659,14 +659,14 @@ class ProcessorTesterMixin:
         all_kwargs = {
             "common_kwargs": {"return_tensors": "pt"},
             "videos_kwargs": {"do_rescale": True, "rescale_factor": -1},
-            "text_kwargs": {"padding": "max_length", "max_length": 76},
+            "text_kwargs": {"padding": "max_length", "max_length": 176},
         }

         inputs = processor(text=input_str, videos=video_input, **all_kwargs)
         self.skip_processor_without_typed_kwargs(processor)

         self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
-        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 176)

     def test_structured_kwargs_nested_from_dict_video(self):
         if "video_processor" not in self.processor_class.attributes:
@@ -682,12 +682,12 @@ class ProcessorTesterMixin:
         all_kwargs = {
             "common_kwargs": {"return_tensors": "pt"},
             "videos_kwargs": {"do_rescale": True, "rescale_factor": -1},
-            "text_kwargs": {"padding": "max_length", "max_length": 76},
+            "text_kwargs": {"padding": "max_length", "max_length": 176},
         }

         inputs = processor(text=input_str, videos=video_input, **all_kwargs)
         self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
-        self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
+        self.assertEqual(inputs[self.text_input_name].shape[-1], 176)

         # TODO: the same test, but for audio + text processors that have strong overlap in kwargs
         # TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
@@ -884,7 +884,7 @@ class ProcessorTesterMixin:
            tokenize=True,
            return_dict=True,
            return_tensors=return_tensors,
-            num_frames=4,  # by default no more than 4 frames, otherwise too slow
+            num_frames=2,  # by default no more than 2 frames, otherwise too slow
         )
         input_name = getattr(self, input_name)
         self.assertTrue(input_name in out_dict)
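All of these bumped constants lean on the fact that `padding="max_length"` pins the tokenized length exactly. A minimal standalone illustration with a generic tokenizer; the checkpoint name is only an example and not tied to this test suite:

from transformers import AutoTokenizer

# Any checkpoint whose tokenizer defines a pad token works for the illustration.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tokenizer("a short prompt", padding="max_length", max_length=176, return_tensors="pt")
print(enc["input_ids"].shape)  # torch.Size([1, 176]) -- the dimension these assertions check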
@@ -983,6 +983,21 @@ class ProcessorTesterMixin:
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), video_fps * 10)

+        # Whan `do_sample_frames=False` no sampling is done and whole video is loaded, even if number of frames is passed
+        video_fps = 1
+        out_dict_with_video = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            do_sample_frames=False,
+            video_fps=video_fps,
+            return_tensors="pt",
+        )
+        self.assertTrue(self.videos_input_name in out_dict_with_video)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
+        self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 300)
+
         # Load with `video_fps` and `num_frames` args, should raise an error
         with self.assertRaises(ValueError):
             out_dict_with_video = processor.apply_chat_template(
@@ -1024,75 +1039,6 @@ class ProcessorTesterMixin:
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)

-    @require_av
-    @require_torch
-    def test_apply_chat_template_video_special_processing(self):
-        """
-        Tests that models can use their own preprocessing to preprocess conversations.
-        """
-        processor = self.get_processor()
-        if processor.chat_template is None:
-            self.skipTest("Processor has no chat template")
-
-        signature = inspect.signature(processor.__call__)
-        if "videos" not in {*signature.parameters.keys()} or (
-            signature.parameters.get("videos") is not None
-            and signature.parameters["videos"].annotation == inspect._empty
-        ):
-            self.skipTest("Processor doesn't accept videos at input")
-
-        video_file_path = hf_hub_download(
-            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
-        )
-        messages = [
-            [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "path": video_file_path},
-                        {"type": "text", "text": "What is shown in this video?"},
-                    ],
-                },
-            ]
-        ]
-
-        def _process_messages_for_chat_template(
-            conversation,
-            batch_images,
-            batch_videos,
-            batch_video_metadata,
-            **chat_template_kwargs,
-        ):
-            # Let us just always return a dummy prompt
-            new_msg = [
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "video"},  # no need to use path, video is loaded already by this moment
-                            {"type": "text", "text": "Dummy prompt for preprocess testing"},
-                        ],
-                    },
-                ]
-            ]
-            return new_msg
-
-        processor._process_messages_for_chat_template = _process_messages_for_chat_template
-        out_dict_with_video = processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        self.assertTrue(self.videos_input_name in out_dict_with_video)
-
-        # Check with `in` because we don't know how each template formats the prompt with BOS/EOS/etc
-        formatted_text = processor.batch_decode(out_dict_with_video["input_ids"], skip_special_tokens=True)[0]
-        self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
-        self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
-        self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
-
     @require_librosa
     @require_av
     def test_chat_template_audio_from_video(self):
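The 300 vs `video_fps * 10` expectations above come down to plain duration arithmetic. A hedged sketch of the reasoning, assuming the test clip is roughly 10 seconds at 30 fps (both numbers are inferred from the assertions rather than stated in the hunk):

# Inferred from the assertions: the test video appears to be ~10 s long at ~30 fps.
duration_s, native_fps = 10, 30

total_frames = duration_s * native_fps   # 300 -> what is returned when do_sample_frames=False
video_fps = 1
sampled_frames = video_fps * duration_s  # 10 -> what is returned when sampling at `video_fps`
print(total_frames, sampled_frames)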
@@ -293,6 +293,59 @@ class VideoProcessingTestMixin:
             (self.video_processor_tester.batch_size, *expected_output_video_shape),
         )

+    def test_call_sample_frames(self):
+        for video_processing_class in self.video_processor_list:
+            video_processing = video_processing_class(**self.video_processor_dict)
+
+            prev_num_frames = self.video_processor_tester.num_frames
+            self.video_processor_tester.num_frames = 8
+            video_inputs = self.video_processor_tester.prepare_video_inputs(
+                equal_resolution=False,
+                return_tensors="torch",
+            )
+
+            # Force set sampling to False. No sampling is expected even when `num_frames` exists
+            video_processing.do_sample_frames = False
+
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
+            encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
+            self.assertEqual(encoded_videos.shape[1], 8)
+            self.assertEqual(encoded_videos_batched.shape[1], 8)
+
+            # Set sampling to True. Video frames should be sampled with `num_frames` in the output
+            video_processing.do_sample_frames = True
+
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=3)[self.input_name]
+            encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=3)[self.input_name]
+            self.assertEqual(encoded_videos.shape[1], 3)
+            self.assertEqual(encoded_videos_batched.shape[1], 3)
+
+            # Sample with `fps` requires metadata to infer number of frames from total duration
+            with self.assertRaises(ValueError):
+                encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3)[self.input_name]
+                encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", fps=3)[self.input_name]
+
+            metadata = [[{"duration": 2.0, "total_num_frames": 8, "fps": 4}]]
+            batched_metadata = metadata * len(video_inputs)
+            encoded_videos = video_processing(video_inputs[0], return_tensors="pt", fps=3, video_metadata=metadata)[
+                self.input_name
+            ]
+            encoded_videos_batched = video_processing(
+                video_inputs, return_tensors="pt", fps=3, video_metadata=batched_metadata
+            )[self.input_name]
+            self.assertEqual(encoded_videos.shape[1], 6)
+            self.assertEqual(encoded_videos_batched.shape[1], 6)
+
+            # We should raise error when asked to sample more frames than there are in input video
+            with self.assertRaises(ValueError):
+                encoded_videos = video_processing(video_inputs[0], return_tensors="pt", num_frames=10)[self.input_name]
+                encoded_videos_batched = video_processing(video_inputs, return_tensors="pt", num_frames=10)[
+                    self.input_name
+                ]
+
+            # Assign back the actual num frames in tester
+            self.video_processor_tester.num_frames = prev_num_frames
+
     def test_nested_input(self):
         """Tests that the processor can work with nested list where each video is a list of arrays"""
         for video_processing_class in self.video_processor_list:
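For readers following how the mixin's assertions fit together, here is a minimal, self-contained sketch of uniform frame sampling under the same rules the test exercises. `sample_frames_uniform` is a hypothetical helper written only for this illustration, not the library's own function:

import torch


def sample_frames_uniform(video: torch.Tensor, num_frames=None, fps=None, metadata=None) -> torch.Tensor:
    """Pick evenly spaced frames from a (frames, channels, height, width) tensor."""
    total = video.shape[0]
    if num_frames is None:
        if fps is None:
            return video  # no sampling requested: return the whole clip
        if metadata is None:
            raise ValueError("Sampling with `fps` needs metadata to know the video duration")
        num_frames = int(metadata["duration"] * fps)  # e.g. 2.0 s * 3 fps = 6 frames
    if num_frames > total:
        raise ValueError(f"Cannot sample {num_frames} frames from a {total}-frame video")
    indices = torch.linspace(0, total - 1, num_frames).long()
    return video[indices]


video = torch.rand(8, 3, 32, 32)
print(sample_frames_uniform(video, num_frames=3).shape)                       # torch.Size([3, 3, 32, 32])
print(sample_frames_uniform(video, fps=3, metadata={"duration": 2.0}).shape)  # torch.Size([6, 3, 32, 32])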