diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0dc62cc6f95..44c9a75aa79 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -39,6 +39,8 @@ title: Tokenizers - local: image_processors title: Image processors + - local: video_processors + title: Video processors - local: backbones title: Backbones - local: feature_extractors @@ -362,7 +364,9 @@ title: Feature Extractor - local: main_classes/image_processor title: Image Processor - title: Main classes + - local: main_classes/video_processor + title: Video Processor + title: Main Classes - sections: - sections: - local: model_doc/albert diff --git a/docs/source/en/image_processors.md b/docs/source/en/image_processors.md index 2e5e466cd5d..feb568bdd3b 100644 --- a/docs/source/en/image_processors.md +++ b/docs/source/en/image_processors.md @@ -16,7 +16,7 @@ rendered properly in your Markdown viewer. # Image processors -Image processors converts images into pixel values, tensors that represent image colors and size. The pixel values are inputs to a vision or video model. To ensure a pretrained model receives the correct input, an image processor can perform the following operations to make sure an image is exactly like the images a model was pretrained on. +Image processors convert images into pixel values, tensors that represent image colors and size. The pixel values are inputs to a vision model. To ensure a pretrained model receives the correct input, an image processor can perform the following operations to make sure an image is exactly like the images a model was pretrained on. - [`~BaseImageProcessor.center_crop`] to resize an image - [`~BaseImageProcessor.normalize`] or [`~BaseImageProcessor.rescale`] pixel values diff --git a/docs/source/en/main_classes/video_processor.md b/docs/source/en/main_classes/video_processor.md new file mode 100644 index 00000000000..bdff30e9c50 --- /dev/null +++ b/docs/source/en/main_classes/video_processor.md @@ -0,0 +1,55 @@ + + + +# Video Processor + +A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion to PyTorch tensors. + +The video processor extends the functionality of image processors by allowing Vision Large Language Models (VLMs) to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM. + +When adding a new VLM or updating an existing one to enable distinct video preprocessing, saving and reloading the processor configuration will store the video-related arguments in a dedicated file named `video_preprocessor_config.json`. Don't worry if you haven't updated your VLM yet: the processor will fall back to loading the video-related configuration from the legacy `preprocessor_config.json` file. + + +### Usage Example +Here's an example of how to load a video processor with the [`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) model: + +```python +from transformers import AutoVideoProcessor + +processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") +``` + +Currently, if you use a base image processor for videos, it processes video data by treating each frame as an individual image and applying the transformations frame by frame. While functional, this approach is not highly efficient; the sketch below shows roughly what it amounts to.
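For illustration only, here is roughly what that per-frame path looks like. This is a minimal sketch, not the library's internal implementation; the toy video array and the explicit Python loop are assumptions made for the example.

```python
import numpy as np
from transformers import AutoImageProcessor

# A toy 8-frame video as a list of height x width x 3 uint8 frames.
video = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(8)]

image_processor = AutoImageProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# Frame-by-frame: every frame is resized, rescaled, and normalized in a separate call,
# so the preprocessing work is repeated once per frame.
pixel_values_per_frame = [
    image_processor(frame, return_tensors="pt")["pixel_values"] for frame in video
]
```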
Using `AutoVideoProcessor` allows us to take advantage of **fast video processors**, leveraging the [torchvision](https://pytorch.org/vision/stable/index.html) library. Fast processors handle the whole batch of videos at once, without iterating over each video or frame. These updates introduce GPU acceleration and significantly enhance processing speed, especially for tasks requiring high throughput. + +Fast video processors are available for all models and are loaded by default when an `AutoVideoProcessor` is initialized. When using a fast video processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise. For even more speed improvement, we can compile the processor when using 'cuda' as device. + +```python +import torch +from transformers.video_utils import load_video +from transformers import AutoVideoProcessor + +video = load_video("video.mp4") +processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda") +processor = torch.compile(processor) +processed_video = processor(video, return_tensors="pt") +``` + + +## BaseVideoProcessor + +[[autodoc]] video_processing_utils.BaseVideoProcessor + diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index 05931285087..afe343228f2 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -74,6 +74,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] AutoImageProcessor +## AutoVideoProcessor + +[[autodoc]] AutoVideoProcessor + ## AutoProcessor [[autodoc]] AutoProcessor diff --git a/docs/source/en/model_doc/instructblipvideo.md b/docs/source/en/model_doc/instructblipvideo.md index c021a4c7afa..fd728c35bb8 100644 --- a/docs/source/en/model_doc/instructblipvideo.md +++ b/docs/source/en/model_doc/instructblipvideo.md @@ -58,6 +58,12 @@ The attributes can be obtained from model config, as `model.config.num_query_tok [[autodoc]] InstructBlipVideoProcessor + +## InstructBlipVideoVideoProcessor + +[[autodoc]] InstructBlipVideoVideoProcessor + - preprocess + ## InstructBlipVideoImageProcessor [[autodoc]] InstructBlipVideoImageProcessor diff --git a/docs/source/en/model_doc/internvl.md b/docs/source/en/model_doc/internvl.md index 8a19726e3e0..97802cb94e2 100644 --- a/docs/source/en/model_doc/internvl.md +++ b/docs/source/en/model_doc/internvl.md @@ -353,3 +353,7 @@ This example showcases how to handle a batch of chat conversations with interlea ## InternVLProcessor [[autodoc]] InternVLProcessor + +## InternVLVideoProcessor + +[[autodoc]] InternVLVideoProcessor diff --git a/docs/source/en/model_doc/llava_next_video.md b/docs/source/en/model_doc/llava_next_video.md index d3306256846..aa611211629 100644 --- a/docs/source/en/model_doc/llava_next_video.md +++ b/docs/source/en/model_doc/llava_next_video.md @@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained( [[autodoc]] LlavaNextVideoImageProcessor +## LlavaNextVideoVideoProcessor + +[[autodoc]] LlavaNextVideoVideoProcessor + ## LlavaNextVideoModel [[autodoc]] LlavaNextVideoModel diff --git a/docs/source/en/model_doc/llava_onevision.md b/docs/source/en/model_doc/llava_onevision.md index a00dd5a0e12..e265177590b 100644 --- a/docs/source/en/model_doc/llava_onevision.md +++ b/docs/source/en/model_doc/llava_onevision.md @@ -303,6 +303,7 @@ model = 
LlavaOnevisionForConditionalGeneration.from_pretrained( ## LlavaOnevisionImageProcessor [[autodoc]] LlavaOnevisionImageProcessor + - preprocess ## LlavaOnevisionImageProcessorFast diff --git a/docs/source/en/model_doc/qwen2_vl.md b/docs/source/en/model_doc/qwen2_vl.md index 7fef4e2fdbd..c6bf692f9de 100644 --- a/docs/source/en/model_doc/qwen2_vl.md +++ b/docs/source/en/model_doc/qwen2_vl.md @@ -287,6 +287,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained( [[autodoc]] Qwen2VLImageProcessor - preprocess +## Qwen2VLVideoProcessor + +[[autodoc]] Qwen2VLVideoProcessor + - preprocess + ## Qwen2VLImageProcessorFast [[autodoc]] Qwen2VLImageProcessorFast diff --git a/docs/source/en/model_doc/smolvlm.md b/docs/source/en/model_doc/smolvlm.md index 9512fb6aa29..d5062b4df66 100644 --- a/docs/source/en/model_doc/smolvlm.md +++ b/docs/source/en/model_doc/smolvlm.md @@ -197,6 +197,9 @@ print(generated_texts[0]) [[autodoc]] SmolVLMImageProcessor - preprocess +## SmolVLMVideoProcessor +[[autodoc]] SmolVLMVideoProcessor + - preprocess ## SmolVLMProcessor [[autodoc]] SmolVLMProcessor diff --git a/docs/source/en/model_doc/video_llava.md b/docs/source/en/model_doc/video_llava.md index ca1a06d4cdc..9eaed2e7d56 100644 --- a/docs/source/en/model_doc/video_llava.md +++ b/docs/source/en/model_doc/video_llava.md @@ -211,6 +211,11 @@ model = VideoLlavaForConditionalGeneration.from_pretrained( [[autodoc]] VideoLlavaImageProcessor + +## VideoLlavaVideoProcessor + +[[autodoc]] VideoLlavaVideoProcessor + ## VideoLlavaProcessor [[autodoc]] VideoLlavaProcessor diff --git a/docs/source/en/video_processors.md b/docs/source/en/video_processors.md new file mode 100644 index 00000000000..4f44914c8cf --- /dev/null +++ b/docs/source/en/video_processors.md @@ -0,0 +1,49 @@ + + + +# Video Processor + +A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion to PyTorch tensors. + +The video processor extends the functionality of image processors by allowing the models to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM. + +Use [`~BaseVideoProcessor.from_pretrained`] to load a video processor's configuration (image size, whether to normalize and rescale, etc.) from a video model on the Hugging Face [Hub](https://hf.co) or a local directory. The configuration for each pretrained model should be saved in a `video_preprocessor_config.json` file, but older models might have the config saved in a [preprocessor_config.json](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf/blob/main/preprocessor_config.json) file. Note that the latter is less preferred and will be removed in the future; loading and re-saving the processor migrates a checkpoint to the new file, as sketched below.
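The snippet below is a minimal sketch of that migration path. The local directory name is illustrative, and it assumes, as described above, that [`~BaseVideoProcessor.save_pretrained`] writes the newer `video_preprocessor_config.json` file.

```python
from transformers import AutoVideoProcessor

# Loading transparently falls back to the legacy preprocessor_config.json when needed.
video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# Saving writes the dedicated video_preprocessor_config.json, so loading and re-saving
# a checkpoint locally is enough to migrate it to the new format.
video_processor.save_pretrained("./my-video-checkpoint")

# Reloading from the local directory now picks up video_preprocessor_config.json.
video_processor = AutoVideoProcessor.from_pretrained("./my-video-checkpoint")
```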
+ + +### Usage Example +Here's an example of how to load a video processor with [`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) model: + +```python +from transformers import AutoVideoProcessor + +processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") +``` + +Currently, if using base image processor for videos, it processes video data by treating each frame as an individual image and applying transformations frame-by-frame. While functional, this approach is not highly efficient. Using `AutoVideoProcessor` allows us to take advantage of **fast video processors**, leveraging the [torchvision](https://pytorch.org/vision/stable/index.html) library. Fast processors handle the whole batch of videos at once, without iterating over each video or frame. These updates introduce GPU acceleration and significantly enhance processing speed, especially for tasks requiring high throughput. + +Fast video processors are available for all models and are loaded by default when an `AutoVideoProcessor` is initialized. When using a fast video processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, the processing is done on the same device as the inputs if the inputs are tensors, or on the CPU otherwise. For even more speed improvement, we can compile the processor when using 'cuda' as device. + +```python +import torch +from transformers.video_utils import load_video +from transformers import AutoVideoProcessor + +video = load_video("video.mp4") +processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda") +processor = torch.compile(processor) +processed_video = processor(video, return_tensors="pt") +``` diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 691f8aad00d..e2d20fb1276 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -276,6 +276,7 @@ _import_structure = { "TorchAoConfig", "VptqConfig", ], + "video_utils": [], } # tokenizers-backed objects @@ -334,6 +335,7 @@ except OptionalDependencyNotAvailable: ] else: _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] + _import_structure["video_processing_utils"] = ["BaseVideoProcessor"] # PyTorch-backed objects try: @@ -809,6 +811,7 @@ if TYPE_CHECKING: from .utils.dummy_torchvision_objects import * else: from .image_processing_utils_fast import BaseImageProcessorFast + from .video_processing_utils import BaseVideoProcessor try: if not (is_torchvision_available() and is_timm_available()): diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 064122bfa7e..feb254f66a3 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -249,7 +249,7 @@ class BaseImageProcessorFast(BaseImageProcessor): Image to resize. size (`SizeDict`): Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. - resample (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. 
Returns: diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index c8a7edd985c..7f9be1c671c 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -56,7 +56,9 @@ def to_channel_dimension_format( input_channel_dim: Optional[Union[ChannelDimension, str]] = None, ) -> np.ndarray: """ - Converts `image` to the channel dimension format specified by `channel_dim`. + Converts `image` to the channel dimension format specified by `channel_dim`. The input + can have arbitrary number of leading dimensions. Only last three dimension will be permuted + to format the `image`. Args: image (`numpy.ndarray`): @@ -80,9 +82,11 @@ def to_channel_dimension_format( return image if target_channel_dim == ChannelDimension.FIRST: - image = image.transpose((2, 0, 1)) + axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2] + image = image.transpose(axes) elif target_channel_dim == ChannelDimension.LAST: - image = image.transpose((1, 2, 0)) + axes = list(range(image.ndim - 3)) + [image.ndim - 2, image.ndim - 1, image.ndim - 3] + image = image.transpose(axes) else: raise ValueError(f"Unsupported channel dimension format: {channel_dim}") diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index f584e6b82c8..c8c6ab79f2b 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -15,11 +15,9 @@ import base64 import os from collections.abc import Iterable -from contextlib import redirect_stdout from dataclasses import dataclass from io import BytesIO -from typing import Callable, Optional, Union -from urllib.parse import urlparse +from typing import Optional, Union import numpy as np import requests @@ -27,9 +25,6 @@ from packaging import version from .utils import ( ExplicitEnum, - is_av_available, - is_cv2_available, - is_decord_available, is_jax_tensor, is_numpy_array, is_tf_tensor, @@ -37,7 +32,6 @@ from .utils import ( is_torch_tensor, is_torchvision_available, is_vision_available, - is_yt_dlp_available, logging, requires_backends, to_numpy, @@ -62,7 +56,6 @@ if is_vision_available(): PILImageResampling = PIL.Image if is_torchvision_available(): - from torchvision import io as torchvision_io from torchvision.transforms import InterpolationMode pil_torch_interpolation_mapping = { @@ -89,18 +82,6 @@ ImageInput = Union[ ] # noqa -VideoInput = Union[ - list["PIL.Image.Image"], - "np.ndarray", - "torch.Tensor", - list["np.ndarray"], - list["torch.Tensor"], - list[list["PIL.Image.Image"]], - list[list["np.ndarray"]], - list[list["torch.Tensor"]], -] # noqa - - class ChannelDimension(ExplicitEnum): FIRST = "channels_first" LAST = "channels_last" @@ -116,14 +97,6 @@ class AnnotionFormat(ExplicitEnum): COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value -@dataclass -class VideoMetadata: - total_num_frames: int - fps: float - duration: float - video_backend: str - - AnnotationType = dict[str, Union[int, str, list[dict]]] @@ -309,37 +282,6 @@ def make_nested_list_of_images( raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.") -def make_batched_videos(videos) -> VideoInput: - """ - Ensure that the input is a list of videos. - Args: - videos (`VideoInput`): - Video or videos to turn into a list of videos. - Returns: - list: A list of videos. 
- """ - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): - # case 1: nested batch of videos so we flatten it - if not is_pil_image(videos[0][0]) and videos[0][0].ndim == 4: - videos = [[video for batch_list in batched_videos for video in batch_list] for batched_videos in videos] - # case 2: list of videos represented as list of video frames - return videos - - elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): - if is_pil_image(videos[0]) or videos[0].ndim == 3: - return [videos] - elif videos[0].ndim == 4: - return [list(video) for video in videos] - - elif is_valid_image(videos): - if is_pil_image(videos) or videos.ndim == 3: - return [[videos]] - elif videos.ndim == 4: - return [list(videos)] - - raise ValueError(f"Could not make batched video from {videos}") - - def to_numpy_array(img) -> np.ndarray: if not is_valid_image(img): raise ValueError(f"Invalid image type: {type(img)}") @@ -371,6 +313,8 @@ def infer_channel_dimension_format( first_dim, last_dim = 0, 2 elif image.ndim == 4: first_dim, last_dim = 1, 3 + elif image.ndim == 5: + first_dim, last_dim = 2, 4 else: raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") @@ -548,348 +492,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = return image -def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs): - """ - A default sampling function that replicates the logic used in get_uniform_frame_indices, - while optionally handling `fps` if `num_frames` is not provided. - - Args: - metadata (`VideoMetadata`): - `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps". - num_frames (`int`, *optional*): - Number of frames to sample uniformly. - fps (`int`, *optional*): - Desired frames per second. Takes priority over num_frames if both are provided. - - Returns: - `np.ndarray`: Array of frame indices to sample. - """ - total_num_frames = metadata.total_num_frames - video_fps = metadata.fps - - # If num_frames is not given but fps is, calculate num_frames from fps - if num_frames is None and fps is not None: - num_frames = int(total_num_frames / video_fps * fps) - if num_frames > total_num_frames: - raise ValueError( - f"When loading the video with fps={fps}, we computed num_frames={num_frames} " - f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata." - ) - - if num_frames is not None: - indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int) - else: - indices = np.arange(0, total_num_frames, dtype=int) - return indices - - -def read_video_opencv( - video_path: str, - sample_indices_fn: Callable, - **kwargs, -): - """ - Decode a video using the OpenCV backend. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. 
- """ - # Lazy import cv2 - requires_backends(read_video_opencv, ["cv2"]) - import cv2 - - video = cv2.VideoCapture(video_path) - total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) - video_fps = video.get(cv2.CAP_PROP_FPS) - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv" - ) - indices = sample_indices_fn(metadata=metadata, **kwargs) - - index = 0 - frames = [] - while video.isOpened(): - success, frame = video.read() - if not success: - break - if index in indices: - height, width, channel = frame.shape - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - frames.append(frame[0:height, 0:width, 0:channel]) - if success: - index += 1 - if index >= total_num_frames: - break - - video.release() - metadata.frames_indices = indices - return np.stack(frames), metadata - - -def read_video_decord( - video_path: str, - sample_indices_fn: Optional[Callable] = None, - **kwargs, -): - """ - Decode a video using the Decord backend. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. - """ - # Lazy import from decord - requires_backends(read_video_decord, ["decord"]) - from decord import VideoReader, cpu - - vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu - video_fps = vr.get_avg_fps() - total_num_frames = len(vr) - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord" - ) - - indices = sample_indices_fn(metadata=metadata, **kwargs) - - frames = vr.get_batch(indices).asnumpy() - metadata.frames_indices = indices - return frames, metadata - - -def read_video_pyav( - video_path: str, - sample_indices_fn: Callable, - **kwargs, -): - """ - Decode the video with PyAV decoder. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. 
- """ - # Lazy import av - requires_backends(read_video_pyav, ["av"]) - import av - - container = av.open(video_path) - total_num_frames = container.streams.video[0].frames - video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`? - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav" - ) - indices = sample_indices_fn(metadata=metadata, **kwargs) - - frames = [] - container.seek(0) - end_index = indices[-1] - for i, frame in enumerate(container.decode(video=0)): - if i > end_index: - break - if i >= 0 and i in indices: - frames.append(frame) - - video = np.stack([x.to_ndarray(format="rgb24") for x in frames]) - metadata.frames_indices = indices - return video, metadata - - -def read_video_torchvision( - video_path: str, - sample_indices_fn: Callable, - **kwargs, -): - """ - Decode the video with torchvision decoder. - - Args: - video_path (`str`): - Path to the video file. - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniform sampling with fps is performed. - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, `VideoMetadata`]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - `VideoMetadata` object. - """ - video, _, info = torchvision_io.read_video( - video_path, - start_pts=0.0, - end_pts=None, - pts_unit="sec", - output_format="THWC", - ) - video_fps = info["video_fps"] - total_num_frames = video.size(0) - duration = total_num_frames / video_fps if video_fps else 0 - metadata = VideoMetadata( - total_num_frames=int(total_num_frames), - fps=float(video_fps), - duration=float(duration), - video_backend="torchvision", - ) - - indices = sample_indices_fn(metadata=metadata, **kwargs) - - video = video[indices].contiguous().numpy() - metadata.frames_indices = indices - return video, metadata - - -VIDEO_DECODERS = { - "decord": read_video_decord, - "opencv": read_video_opencv, - "pyav": read_video_pyav, - "torchvision": read_video_torchvision, -} - - -def load_video( - video: Union[str, "VideoInput"], - num_frames: Optional[int] = None, - fps: Optional[int] = None, - backend: str = "opencv", - sample_indices_fn: Optional[Callable] = None, - **kwargs, -) -> np.array: - """ - Loads `video` to a numpy array. - - Args: - video (`str` or `VideoInput`): - The video to convert to the numpy array format. Can be a link to video or local path. - num_frames (`int`, *optional*): - Number of frames to sample uniformly. If not passed, the whole video is loaded. - fps (`int`, *optional*): - Number of frames to sample per second. Should be passed only when `num_frames=None`. - If not specified and `num_frames==None`, all frames are sampled. - backend (`str`, *optional*, defaults to `"opencv"`): - The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv". - sample_indices_fn (`Callable`, *optional*): - A callable function that will return indices at which the video should be sampled. 
If the video has to be loaded using - by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`. - If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args. - The function expects at input the all args along with all kwargs passed to `load_video` and should output valid - indices at which the video should be sampled. For example: - - Example: - def sample_indices_fn(metadata, **kwargs): - return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int) - - Returns: - Tuple[`np.array`, Dict]: A tuple containing: - - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]). - - Metadata dictionary. - """ - - # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn` - if fps is not None and num_frames is not None and sample_indices_fn is None: - raise ValueError( - "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!" - ) - - # If user didn't pass a sampling function, create one on the fly with default logic - if sample_indices_fn is None: - - def sample_indices_fn_func(metadata, **fn_kwargs): - return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs) - - sample_indices_fn = sample_indices_fn_func - - if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]: - if not is_yt_dlp_available(): - raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.") - # Lazy import from yt_dlp - requires_backends(load_video, ["yt_dlp"]) - from yt_dlp import YoutubeDL - - buffer = BytesIO() - with redirect_stdout(buffer), YoutubeDL() as f: - f.download([video]) - bytes_obj = buffer.getvalue() - file_obj = BytesIO(bytes_obj) - elif video.startswith("http://") or video.startswith("https://"): - file_obj = BytesIO(requests.get(video).content) - elif os.path.isfile(video): - file_obj = video - elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])): - file_obj = None - else: - raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.") - - # can also load with decord, but not cv2/torchvision - # both will fail in case of url links - video_is_url = video.startswith("http://") or video.startswith("https://") - if video_is_url and backend in ["opencv", "torchvision"]: - raise ValueError( - "If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend" - ) - - if file_obj is None: - return video - - if ( - (not is_decord_available() and backend == "decord") - or (not is_av_available() and backend == "pyav") - or (not is_cv2_available() and backend == "opencv") - or (not is_torchvision_available() and backend == "torchvision") - ): - raise ImportError( - f"You chose backend={backend} for loading the video but the required library is not found in your environment " - f"Make sure to install {backend} before loading the video." 
- ) - - video_decoder = VIDEO_DECODERS[backend] - video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs) - return video, metadata - - def load_images( images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None ) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]: diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 6828030287e..34a6ae1e5c2 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from .modeling_tf_auto import * from .processing_auto import * from .tokenization_auto import * + from .video_processing_auto import * else: import sys diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index d49301cd5f4..d43463720a4 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -28,7 +28,14 @@ from ...feature_extraction_utils import FeatureExtractionMixin from ...image_processing_utils import ImageProcessingMixin from ...processing_utils import ProcessorMixin from ...tokenization_utils import TOKENIZER_CONFIG_FILE -from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging +from ...utils import ( + FEATURE_EXTRACTOR_NAME, + PROCESSOR_NAME, + VIDEO_PROCESSOR_NAME, + cached_file, + logging, +) +from ...video_processing_utils import BaseVideoProcessor from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, @@ -295,14 +302,31 @@ class AutoProcessor: if "AutoProcessor" in config_dict.get("auto_map", {}): processor_auto_map = config_dict["auto_map"]["AutoProcessor"] - # If not found, let's check whether the processor class is saved in a feature extractor config - if preprocessor_config_file is not None and processor_class is None: - config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict( - pretrained_model_name_or_path, **kwargs + # Saved as video processor + if preprocessor_config_file is None: + preprocessor_config_file = cached_file( + pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs ) - processor_class = config_dict.get("processor_class", None) - if "AutoProcessor" in config_dict.get("auto_map", {}): - processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + if preprocessor_config_file is not None: + config_dict, _ = BaseVideoProcessor.get_video_processor_dict( + pretrained_model_name_or_path, **kwargs + ) + processor_class = config_dict.get("processor_class", None) + if "AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] + + # Saved as feature extractor + if preprocessor_config_file is None: + preprocessor_config_file = cached_file( + pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs + ) + if preprocessor_config_file is not None and processor_class is None: + config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict( + pretrained_model_name_or_path, **kwargs + ) + processor_class = config_dict.get("processor_class", None) + if "AutoProcessor" in config_dict.get("auto_map", {}): + processor_auto_map = config_dict["auto_map"]["AutoProcessor"] if processor_class is None: # Next, let's check whether the processor class is saved in a tokenizer diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py new 
file mode 100644 index 00000000000..46cd1289cfa --- /dev/null +++ b/src/transformers/models/auto/video_processing_auto.py @@ -0,0 +1,384 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AutoVideoProcessor class.""" + +import importlib +import json +import os +import warnings +from collections import OrderedDict +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +# Build the list of all video processors +from ...configuration_utils import PretrainedConfig +from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code +from ...utils import ( + CONFIG_NAME, + VIDEO_PROCESSOR_NAME, + cached_file, + is_torchvision_available, + logging, +) +from ...utils.import_utils import requires +from ...video_processing_utils import BaseVideoProcessor +from .auto_factory import _LazyAutoMapping +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + AutoConfig, + model_type_to_module_name, + replace_list_option_in_docstrings, +) + + +logger = logging.get_logger(__name__) + + +if TYPE_CHECKING: + # This significantly improves completion suggestion performance when + # the transformers package is used with Microsoft's Pylance language server. + VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict() +else: + VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ("instructblipvideo", "InstructBlipVideoVideoProcessor"), + ("llava_next_video", "LlavaNextVideoVideoProcessor"), + ("llava_onevision", "LlavaOnevisionVideoProcessor"), + ("qwen2_5_vl", "Qwen2_5_VLVideoProcessor"), + ("qwen2_vl", "Qwen2VLVideoProcessor"), + ("video_llava", "VideoLlavaVideoProcessor"), + ] + ) + +for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items(): + fast_video_processor_class = video_processors + + # If the torchvision is not available, we set it to None + if not is_torchvision_available(): + fast_video_processor_class = None + + VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class + +VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES) + + +def video_processor_class_from_name(class_name: str): + for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items(): + if class_name in extractors: + module_name = model_type_to_module_name(module_name) + + module = importlib.import_module(f".{module_name}", "transformers.models") + try: + return getattr(module, class_name) + except AttributeError: + continue + + for _, extractor in VIDEO_PROCESSOR_MAPPING._extra_content.items(): + if getattr(extractor, "__name__", None) == class_name: + return extractor + + # We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main + # init and we return the proper dummy to get an appropriate error message. 
+ main_module = importlib.import_module("transformers") + if hasattr(main_module, class_name): + return getattr(main_module, class_name) + + return None + + +def get_video_processor_config( + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + **kwargs, +): + """ + Loads the video processor configuration from a pretrained model video processor configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a configuration file saved using the + [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the video processor configuration from local files. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `Dict`: The configuration of the video processor. + + Examples: + + ```python + # Download configuration from huggingface.co and cache. + video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + # This model does not have a video processor config so the result will be an empty dict. + video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base") + + # Save a pretrained video processor locally and you can reload its config + from transformers import AutoVideoProcessor + + video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + video_processor.save_pretrained("video-processor-test") + video_processor = get_video_processor_config("video-processor-test") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. 
Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + resolved_config_file = cached_file( + pretrained_model_name_or_path, + VIDEO_PROCESSOR_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + ) + if resolved_config_file is None: + logger.info( + "Could not locate the video processor configuration file, will try to use the model config instead." + ) + return {} + + with open(resolved_config_file, encoding="utf-8") as reader: + return json.load(reader) + + +@requires(backends=("vision", "torchvision")) +class AutoVideoProcessor: + r""" + This is a generic video processor class that will be instantiated as one of the video processor classes of the + library when created with the [`AutoVideoProcessor.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoVideoProcessor is designed to be instantiated " + "using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + @replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES) + def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + r""" + Instantiate one of the video processor classes of the library from a pretrained model vocabulary. + + The video processor class to instantiate is selected based on the `model_type` property of the config object + (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's + missing, by falling back to using pattern matching on `pretrained_model_name_or_path`: + + List options + + Params: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained video_processor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a video processor file saved using the + [`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g., + `./my_model_directory/`. + - a path or url to a saved video processor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model video processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the video processor files and override the cached versions if + they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final video processor object. If `True`, then this + functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of + `kwargs` which has not been used to update `video_processor` and is otherwise ignored. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + kwargs (`Dict[str, Any]`, *optional*): + The values in kwargs of any keys which are video processor attributes will be used to override the + loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is + controlled by the `return_unused_kwargs` keyword parameter. + + + + Passing `token=True` is required when you want to use a private model. + + + + Examples: + + ```python + >>> from transformers import AutoVideoProcessor + + >>> # Download video processor from huggingface.co and cache. + >>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf") + + >>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*) + >>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/") + ```""" + use_auth_token = kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if kwargs.get("token", None) is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + kwargs["token"] = use_auth_token + + config = kwargs.pop("config", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs["_from_auto"] = True + + config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs) + video_processor_class = config_dict.get("video_processor_type", None) + video_processor_auto_map = None + if "AutoVideoProcessor" in config_dict.get("auto_map", {}): + video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"] + + # If we still don't have the video processor class, check if we're loading from a previous feature extractor config + # and if so, infer the video processor class from there. 
+ if video_processor_class is None and video_processor_auto_map is None: + feature_extractor_class = config_dict.pop("feature_extractor_type", None) + if feature_extractor_class is not None: + video_processor_class = feature_extractor_class.replace("FeatureExtractor", "VideoProcessor") + if "AutoFeatureExtractor" in config_dict.get("auto_map", {}): + feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"] + video_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "VideoProcessor") + + # If we don't find the video processor class in the video processor config, let's try the model config. + if video_processor_class is None and video_processor_auto_map is None: + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + # It could be in `config.video_processor_type`` + video_processor_class = getattr(config, "video_processor_type", None) + if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map: + video_processor_auto_map = config.auto_map["AutoVideoProcessor"] + + if video_processor_class is not None: + video_processor_class = video_processor_class_from_name(video_processor_class) + + has_remote_code = video_processor_auto_map is not None + has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code + ) + + if has_remote_code and trust_remote_code: + class_ref = video_processor_auto_map + video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) + _ = kwargs.pop("code_revision", None) + if os.path.isdir(pretrained_model_name_or_path): + video_processor_class.register_for_auto_class() + return video_processor_class.from_dict(config_dict, **kwargs) + elif video_processor_class is not None: + return video_processor_class.from_dict(config_dict, **kwargs) + # Last try: we use the VIDEO_PROCESSOR_MAPPING. + elif type(config) in VIDEO_PROCESSOR_MAPPING: + video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)] + + if video_processor_class is not None: + return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + else: + raise ValueError( + "This video processor cannot be instantiated. Please make sure you have `torchvision` installed." + ) + + raise ValueError( + f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a " + f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following " + f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES.keys())}" + ) + + @staticmethod + def register( + config_class, + video_processor_class, + exist_ok=False, + ): + """ + Register a new video processor for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + video_processor_class ([`BaseVideoProcessor`]): + The video processor to register. 
+ """ + VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok) + + +__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"] diff --git a/src/transformers/models/emu3/image_processing_emu3.py b/src/transformers/models/emu3/image_processing_emu3.py index 3780de93c36..3cbe6fe5dc8 100644 --- a/src/transformers/models/emu3/image_processing_emu3.py +++ b/src/transformers/models/emu3/image_processing_emu3.py @@ -27,7 +27,6 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - VideoInput, get_image_size, infer_channel_dimension_format, is_scaled_image, @@ -166,7 +165,7 @@ class Emu3ImageProcessor(BaseImageProcessor): def _preprocess( self, - images: Union[ImageInput, VideoInput], + images: ImageInput, do_resize: Optional[bool] = None, resample: PILImageResampling = None, do_rescale: Optional[bool] = None, diff --git a/src/transformers/models/instructblipvideo/__init__.py b/src/transformers/models/instructblipvideo/__init__.py index 816c6b23052..2eb06450487 100644 --- a/src/transformers/models/instructblipvideo/__init__.py +++ b/src/transformers/models/instructblipvideo/__init__.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from .image_processing_instructblipvideo import * from .modeling_instructblipvideo import * from .processing_instructblipvideo import * + from .video_processing_instructblipvideo import * else: import sys diff --git a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py index 32018a79542..436ce86eb43 100644 --- a/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/image_processing_instructblipvideo.py @@ -29,20 +29,20 @@ from ...image_utils import ( ChannelDimension, ImageInput, PILImageResampling, - VideoInput, infer_channel_dimension_format, is_scaled_image, - make_batched_videos, to_numpy_array, valid_images, validate_preprocess_arguments, ) from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...video_utils import VideoInput, make_batched_videos logger = logging.get_logger(__name__) +# TODO (raushan): processor can be removed after v5 release. Kept for backwards compatibility # Copied from transformers.models.blip.image_processing_blip.BlipImageProcessor with Blip->InstructBlipVideo, BLIP->InstructBLIPVideo class InstructBlipVideoImageProcessor(BaseImageProcessor): r""" @@ -236,6 +236,10 @@ class InstructBlipVideoImageProcessor(BaseImageProcessor): size = get_size_dict(size, default_to_square=False) videos = make_batched_videos(images) + logger.warning( + "`InstructBlipVideoImageProcessor` is deprecated and will be removed in v5.0. " + "We recommend to load an instance of `InstructBlipVideoVideoProcessor` to process videos for the model. 
" + ) validate_preprocess_arguments( do_rescale=do_rescale, diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index 427a12d68ad..8c59606e4b6 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -20,7 +20,6 @@ import os from typing import List, Optional, Union from ...image_processing_utils import BatchFeature -from ...image_utils import VideoInput from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import ( AddedToken, @@ -31,6 +30,7 @@ from ...tokenization_utils_base import ( TruncationStrategy, ) from ...utils import TensorType, logging +from ...video_utils import VideoInput from ..auto import AutoTokenizer @@ -46,8 +46,8 @@ class InstructBlipVideoProcessor(ProcessorMixin): docstring of [`~InstructBlipVideoProcessor.__call__`] and [`~InstructBlipVideoProcessor.decode`] for more information. Args: - image_processor (`InstructBlipVideoImageProcessor`): - An instance of [`InstructBlipVideoImageProcessor`]. The image processor is a required input. + video_processor (`InstructBlipVideoVideoProcessor`): + An instance of [`InstructBlipVideoVideoProcessor`]. The video processor is a required input. tokenizer (`AutoTokenizer`): An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input. qformer_tokenizer (`AutoTokenizer`): @@ -56,20 +56,20 @@ class InstructBlipVideoProcessor(ProcessorMixin): Number of tokens used by the Qformer as queries, should be same as in model's config. """ - attributes = ["image_processor", "tokenizer", "qformer_tokenizer"] + attributes = ["video_processor", "tokenizer", "qformer_tokenizer"] valid_kwargs = ["num_query_tokens"] - image_processor_class = "InstructBlipVideoImageProcessor" + video_processor_class = "AutoVideoProcessor" tokenizer_class = "AutoTokenizer" qformer_tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs): + def __init__(self, video_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs): if not hasattr(tokenizer, "video_token"): self.video_token = AddedToken("