Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
🔴 Video processors as a separate class (#35206)
* initial design
* update all video processors
* add tests
* need to add qwen2-vl (not tested yet)
* add qwen2-vl in auto map
* fix copies
* isort
* resolve confilicts kinda
* nit:
* qwen2-vl is happy now
* qwen2-5 happy
* other models are happy
* fix copies
* fix tests
* add docs
* CI green now?
* add more tests
* even more changes + tests
* doc builder fail
* nit
* Update src/transformers/models/auto/processing_auto.py
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* small update
* imports correctly
* dump, otherwise this is getting unmanagebale T-T
* dump
* update
* another update
* update
* tests
* move
* modular
* docs
* test
* another update
* init
* remove flakiness in tests
* fixup
* clean up and remove commented lines
* docs
* skip this one!
* last fix after rebasing
* run fixup
* delete slow files
* remove unnecessary tests + clean up a bit
* small fixes
* fix tests
* more updates
* docs
* fix tests
* update
* style
* fix qwen2-5-vl
* fixup
* fixup
* unflatten batch when preparing
* dump, come back soon
* add docs and fix some tests
* how to guard this with new dummies?
* chat templates in qwen
* address some comments
* remove `Fast` suffix
* fixup
* oops should be imported from transforms
* typo in requires dummies
* new model added with video support
* fixup once more
* last fixup I hope
* revert image processor name + comments
* oh, this is why fetch test is failing
* fix tests
* fix more tests
* fixup
* add new models: internvl, smolvlm
* update docs
* imprt once
* fix failing tests
* do we need to guard it here again, why?
* new model was added, update it
* remove testcase from tester
* fix tests
* make style
* not related CI fail, lets' just fix here
* mark flaky for now, filas 15 out of 100
* style
* maybe we can do this way?
* don't download images in setup class

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
parent 716819b830
commit a31fa218ad
@ -39,6 +39,8 @@
|
||||
title: Tokenizers
|
||||
- local: image_processors
|
||||
title: Image processors
|
||||
- local: video_processors
|
||||
title: Video processors
|
||||
- local: backbones
|
||||
title: Backbones
|
||||
- local: feature_extractors
|
||||
@ -362,7 +364,9 @@
|
||||
title: Feature Extractor
|
||||
- local: main_classes/image_processor
|
||||
title: Image Processor
|
||||
title: Main classes
|
||||
- local: main_classes/video_processor
|
||||
title: Video Processor
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- sections:
|
||||
- local: model_doc/albert
|
||||
|
@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Image processors
|
||||
|
||||
Image processors convert images into pixel values, tensors that represent image colors and size. The pixel values are inputs to a vision or video model. To ensure a pretrained model receives the correct input, an image processor can perform the following operations so that an image matches the images the model was pretrained on.
|
||||
Image processors convert images into pixel values, tensors that represent image colors and size. The pixel values are inputs to a vision model. To ensure a pretrained model receives the correct input, an image processor can perform the following operations so that an image matches the images the model was pretrained on.
|
||||
|
||||
- [`~BaseImageProcessor.center_crop`] to resize an image
|
||||
- [`~BaseImageProcessor.normalize`] or [`~BaseImageProcessor.rescale`] pixel values
|
||||
|
55
docs/source/en/main_classes/video_processor.md
Normal file
@ -0,0 +1,55 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# Video Processor
|
||||
|
||||
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch tensors.
|
||||
|
||||
The video processor extends the functionality of image processors by allowing Vision Large Language Models (VLMs) to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.
|
||||
|
||||
When adding a new VLM or updating an existing one to enable distinct video preprocessing, saving and reloading the processor configuration will store the video-related arguments in a dedicated file named `video_preprocessor_config.json`. Don't worry if you haven't updated your VLM yet; the processor will try to load the video-related configuration from a file named `preprocessor_config.json`.
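For instance, saving a processor locally writes that dedicated file next to the rest of the processor state. A minimal sketch (the local path is just an example):

```python
from transformers import AutoVideoProcessor

processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
# writes video_preprocessor_config.json into the target directory
processor.save_pretrained("./my-video-processor")
```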
|
||||
|
||||
|
||||
### Usage Example
|
||||
Here's an example of how to load a video processor for the [`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) model:
|
||||
|
||||
```python
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
```
|
||||
|
||||
Currently, if you use a base image processor for videos, video data is processed by treating each frame as an individual image and applying transformations frame by frame. While functional, this approach is not highly efficient. Using `AutoVideoProcessor` lets you take advantage of **fast video processors**, which leverage the [torchvision](https://pytorch.org/vision/stable/index.html) library. Fast processors handle the whole batch of videos at once, without iterating over each video or frame, which enables GPU acceleration and significantly higher processing speed, especially for tasks requiring high throughput.
|
||||
|
||||
Fast video processors are available for all models and are loaded by default when an `AutoVideoProcessor` is initialized. When using a fast video processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, processing is done on the same device as the inputs if they are tensors, and on the CPU otherwise. For an additional speedup, you can compile the processor with `torch.compile` when using `"cuda"` as the device.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers.video_utils import load_video
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
video = load_video("video.mp4")
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
|
||||
processor = torch.compile(processor)
|
||||
processed_video = processor(video, return_tensors="pt")
|
||||
```
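The exact keys of the returned batch depend on the model's video processor; for LLaVA-OneVision, for example, the video tensor is typically stored under `pixel_values_videos`. A quick way to check, assuming the snippet above has run:

```python
print(processed_video.keys())  # e.g. dict_keys(['pixel_values_videos'])
```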
|
||||
|
||||
|
||||
## BaseVideoProcessor
|
||||
|
||||
[[autodoc]] video_processing_utils.BaseVideoProcessor
|
||||
|
@ -74,6 +74,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
||||
|
||||
[[autodoc]] AutoImageProcessor
|
||||
|
||||
## AutoVideoProcessor
|
||||
|
||||
[[autodoc]] AutoVideoProcessor
|
||||
|
||||
## AutoProcessor
|
||||
|
||||
[[autodoc]] AutoProcessor
|
||||
|
@ -58,6 +58,12 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
|
||||
|
||||
[[autodoc]] InstructBlipVideoProcessor
|
||||
|
||||
|
||||
## InstructBlipVideoVideoProcessor
|
||||
|
||||
[[autodoc]] InstructBlipVideoVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## InstructBlipVideoImageProcessor
|
||||
|
||||
[[autodoc]] InstructBlipVideoImageProcessor
|
||||
|
@ -353,3 +353,7 @@ This example showcases how to handle a batch of chat conversations with interlea
|
||||
## InternVLProcessor
|
||||
|
||||
[[autodoc]] InternVLProcessor
|
||||
|
||||
## InternVLVideoProcessor
|
||||
|
||||
[[autodoc]] InternVLVideoProcessor
|
||||
|
@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] LlavaNextVideoImageProcessor
|
||||
|
||||
## LlavaNextVideoVideoProcessor
|
||||
|
||||
[[autodoc]] LlavaNextVideoVideoProcessor
|
||||
|
||||
## LlavaNextVideoModel
|
||||
|
||||
[[autodoc]] LlavaNextVideoModel
|
||||
|
@ -303,6 +303,7 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
## LlavaOnevisionImageProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionImageProcessor
|
||||
- preprocess
|
||||
|
||||
## LlavaOnevisionImageProcessorFast
|
||||
|
||||
@ -313,6 +314,10 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] LlavaOnevisionVideoProcessor
|
||||
|
||||
## LlavaOnevisionVideoProcessor
|
||||
|
||||
[[autodoc]] LlavaOnevisionVideoProcessor
|
||||
|
||||
## LlavaOnevisionModel
|
||||
|
||||
[[autodoc]] LlavaOnevisionModel
|
||||
|
@ -287,6 +287,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
[[autodoc]] Qwen2VLImageProcessor
|
||||
- preprocess
|
||||
|
||||
## Qwen2VLVideoProcessor
|
||||
|
||||
[[autodoc]] Qwen2VLVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## Qwen2VLImageProcessorFast
|
||||
|
||||
[[autodoc]] Qwen2VLImageProcessorFast
|
||||
|
@ -197,6 +197,9 @@ print(generated_texts[0])
|
||||
[[autodoc]] SmolVLMImageProcessor
|
||||
- preprocess
|
||||
|
||||
## SmolVLMVideoProcessor
|
||||
[[autodoc]] SmolVLMVideoProcessor
|
||||
- preprocess
|
||||
|
||||
## SmolVLMProcessor
|
||||
[[autodoc]] SmolVLMProcessor
|
||||
|
@ -211,6 +211,11 @@ model = VideoLlavaForConditionalGeneration.from_pretrained(
|
||||
|
||||
[[autodoc]] VideoLlavaImageProcessor
|
||||
|
||||
|
||||
## VideoLlavaVideoProcessor
|
||||
|
||||
[[autodoc]] VideoLlavaVideoProcessor
|
||||
|
||||
## VideoLlavaProcessor
|
||||
|
||||
[[autodoc]] VideoLlavaProcessor
|
||||
|
49
docs/source/en/video_processors.md
Normal file
@ -0,0 +1,49 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# Video Processor
|
||||
|
||||
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch tensors.
|
||||
|
||||
The video processor extends the functionality of image processors by allowing models to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.
|
||||
|
||||
Use [`~BaseVideoProcessor.from_pretrained`] to load a video processor's configuration (image size, whether to normalize and rescale, etc.) from a video model on the Hugging Face [Hub](https://hf.co) or a local directory. The configuration for each pretrained model should be saved in a `video_preprocessor_config.json` file, but older models might have the config saved in a [preprocessor_config.json](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf/blob/main/preprocessor_config.json) file. Note that the latter is deprecated and will be removed in the future.
|
||||
|
||||
|
||||
### Usage Example
|
||||
Here's an example of how to load a video processor for the [`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf) model:
|
||||
|
||||
```python
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
```
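Once loaded, the configuration can be saved to and reloaded from a local directory as well. A minimal sketch (the directory name is arbitrary):

```python
processor.save_pretrained("./llava-onevision-video-processor")
reloaded_processor = AutoVideoProcessor.from_pretrained("./llava-onevision-video-processor")
```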
|
||||
|
||||
Currently, if you use a base image processor for videos, video data is processed by treating each frame as an individual image and applying transformations frame by frame. While functional, this approach is not highly efficient. Using `AutoVideoProcessor` lets you take advantage of **fast video processors**, which leverage the [torchvision](https://pytorch.org/vision/stable/index.html) library. Fast processors handle the whole batch of videos at once, without iterating over each video or frame, which enables GPU acceleration and significantly higher processing speed, especially for tasks requiring high throughput.
|
||||
|
||||
Fast video processors are available for all models and are loaded by default when an `AutoVideoProcessor` is initialized. When using a fast video processor, you can also set the `device` argument to specify the device on which the processing should be done. By default, processing is done on the same device as the inputs if they are tensors, and on the CPU otherwise. For an additional speedup, you can compile the processor with `torch.compile` when using `"cuda"` as the device.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers.video_utils import load_video
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
video = load_video("video.mp4")
|
||||
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
|
||||
processor = torch.compile(processor)
|
||||
processed_video = processor(video, return_tensors="pt")
|
||||
```
|
@ -276,6 +276,7 @@ _import_structure = {
|
||||
"TorchAoConfig",
|
||||
"VptqConfig",
|
||||
],
|
||||
"video_utils": [],
|
||||
}
|
||||
|
||||
# tokenizers-backed objects
|
||||
@ -334,6 +335,7 @@ except OptionalDependencyNotAvailable:
|
||||
]
|
||||
else:
|
||||
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
|
||||
_import_structure["video_processing_utils"] = ["BaseVideoProcessor"]
|
||||
|
||||
# PyTorch-backed objects
|
||||
try:
|
||||
@ -809,6 +811,7 @@ if TYPE_CHECKING:
|
||||
from .utils.dummy_torchvision_objects import *
|
||||
else:
|
||||
from .image_processing_utils_fast import BaseImageProcessorFast
|
||||
from .video_processing_utils import BaseVideoProcessor
|
||||
|
||||
try:
|
||||
if not (is_torchvision_available() and is_timm_available()):
|
||||
|
@ -249,7 +249,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
|
||||
Image to resize.
|
||||
size (`SizeDict`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
|
||||
|
||||
Returns:
|
||||
|
@ -56,7 +56,9 @@ def to_channel_dimension_format(
|
||||
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts `image` to the channel dimension format specified by `channel_dim`.
|
||||
Converts `image` to the channel dimension format specified by `channel_dim`. The input
|
||||
can have an arbitrary number of leading dimensions. Only the last three dimensions will be permuted
|
||||
to format the `image`.
|
||||
|
||||
Args:
|
||||
image (`numpy.ndarray`):
|
||||
@ -80,9 +82,11 @@ def to_channel_dimension_format(
|
||||
return image
|
||||
|
||||
if target_channel_dim == ChannelDimension.FIRST:
|
||||
image = image.transpose((2, 0, 1))
|
||||
axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2]
|
||||
image = image.transpose(axes)
|
||||
elif target_channel_dim == ChannelDimension.LAST:
|
||||
image = image.transpose((1, 2, 0))
|
||||
axes = list(range(image.ndim - 3)) + [image.ndim - 2, image.ndim - 1, image.ndim - 3]
|
||||
image = image.transpose(axes)
|
||||
else:
|
||||
raise ValueError(f"Unsupported channel dimension format: {channel_dim}")
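The practical effect of the change above is that inputs with extra leading dimensions (for example stacked video frames) now have only their last three axes permuted. A minimal sketch of the intended behavior, assuming the helpers keep their current public locations:

```python
import numpy as np

from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension

# a batch of 2 videos with 8 frames each, stored channels-last (HWC)
video = np.zeros((2, 8, 224, 224, 3))
out = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=ChannelDimension.LAST)
print(out.shape)  # (2, 8, 3, 224, 224) -- only the trailing three axes were permuted
```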
|
||||
|
||||
|
@ -15,11 +15,9 @@
|
||||
import base64
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Callable, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@ -27,9 +25,6 @@ from packaging import version
|
||||
|
||||
from .utils import (
|
||||
ExplicitEnum,
|
||||
is_av_available,
|
||||
is_cv2_available,
|
||||
is_decord_available,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_tf_tensor,
|
||||
@ -37,7 +32,6 @@ from .utils import (
|
||||
is_torch_tensor,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
is_yt_dlp_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
to_numpy,
|
||||
@ -62,7 +56,6 @@ if is_vision_available():
|
||||
PILImageResampling = PIL.Image
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import io as torchvision_io
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
pil_torch_interpolation_mapping = {
|
||||
@ -89,18 +82,6 @@ ImageInput = Union[
|
||||
] # noqa
|
||||
|
||||
|
||||
VideoInput = Union[
|
||||
list["PIL.Image.Image"],
|
||||
"np.ndarray",
|
||||
"torch.Tensor",
|
||||
list["np.ndarray"],
|
||||
list["torch.Tensor"],
|
||||
list[list["PIL.Image.Image"]],
|
||||
list[list["np.ndarray"]],
|
||||
list[list["torch.Tensor"]],
|
||||
] # noqa
|
||||
|
||||
|
||||
class ChannelDimension(ExplicitEnum):
|
||||
FIRST = "channels_first"
|
||||
LAST = "channels_last"
|
||||
@ -116,14 +97,6 @@ class AnnotionFormat(ExplicitEnum):
|
||||
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoMetadata:
|
||||
total_num_frames: int
|
||||
fps: float
|
||||
duration: float
|
||||
video_backend: str
|
||||
|
||||
|
||||
AnnotationType = dict[str, Union[int, str, list[dict]]]
|
||||
|
||||
|
||||
@ -309,37 +282,6 @@ def make_nested_list_of_images(
|
||||
raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")
|
||||
|
||||
|
||||
def make_batched_videos(videos) -> VideoInput:
|
||||
"""
|
||||
Ensure that the input is a list of videos.
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video or videos to turn into a list of videos.
|
||||
Returns:
|
||||
list: A list of videos.
|
||||
"""
|
||||
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
||||
# case 1: nested batch of videos so we flatten it
|
||||
if not is_pil_image(videos[0][0]) and videos[0][0].ndim == 4:
|
||||
videos = [[video for batch_list in batched_videos for video in batch_list] for batched_videos in videos]
|
||||
# case 2: list of videos represented as list of video frames
|
||||
return videos
|
||||
|
||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
||||
if is_pil_image(videos[0]) or videos[0].ndim == 3:
|
||||
return [videos]
|
||||
elif videos[0].ndim == 4:
|
||||
return [list(video) for video in videos]
|
||||
|
||||
elif is_valid_image(videos):
|
||||
if is_pil_image(videos) or videos.ndim == 3:
|
||||
return [[videos]]
|
||||
elif videos.ndim == 4:
|
||||
return [list(videos)]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {videos}")
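This helper moves to `video_utils` in this PR (see the updated imports further down). A minimal sketch of how it normalizes input nesting, assuming the moved version keeps the behavior shown above:

```python
import numpy as np

from transformers.video_utils import make_batched_videos

# a single video given as a flat list of frames...
frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(4)]
print(len(make_batched_videos(frames)))  # 1 video in the batch

# ...or as a single 4D array (frames, height, width, channels)
video = np.zeros((4, 224, 224, 3), dtype=np.uint8)
print(len(make_batched_videos(video)))   # 1 video in the batch
```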
|
||||
|
||||
|
||||
def to_numpy_array(img) -> np.ndarray:
|
||||
if not is_valid_image(img):
|
||||
raise ValueError(f"Invalid image type: {type(img)}")
|
||||
@ -371,6 +313,8 @@ def infer_channel_dimension_format(
|
||||
first_dim, last_dim = 0, 2
|
||||
elif image.ndim == 4:
|
||||
first_dim, last_dim = 1, 3
|
||||
elif image.ndim == 5:
|
||||
first_dim, last_dim = 2, 4
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
|
||||
|
||||
@ -548,348 +492,6 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
|
||||
return image
|
||||
|
||||
|
||||
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
|
||||
"""
|
||||
A default sampling function that replicates the logic used in get_uniform_frame_indices,
|
||||
while optionally handling `fps` if `num_frames` is not provided.
|
||||
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly.
|
||||
fps (`int`, *optional*):
|
||||
Desired frames per second. Takes priority over num_frames if both are provided.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Array of frame indices to sample.
|
||||
"""
|
||||
total_num_frames = metadata.total_num_frames
|
||||
video_fps = metadata.fps
|
||||
|
||||
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
|
||||
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
|
||||
)
|
||||
|
||||
if num_frames is not None:
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
|
||||
else:
|
||||
indices = np.arange(0, total_num_frames, dtype=int)
|
||||
return indices
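For example, a 10-second clip at 30 fps sampled with `fps=2` keeps roughly two frames per second. A small sketch, assuming these helpers are exposed from `transformers.video_utils` after the refactor:

```python
from transformers.video_utils import VideoMetadata, default_sample_indices_fn

metadata = VideoMetadata(total_num_frames=300, fps=30.0, duration=10.0, video_backend="opencv")
indices = default_sample_indices_fn(metadata, fps=2)
print(len(indices))  # 20 -> int(300 / 30 * 2) evenly spaced frame indices
```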
|
||||
|
||||
|
||||
def read_video_opencv(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the OpenCV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import cv2
|
||||
requires_backends(read_video_opencv, ["cv2"])
|
||||
import cv2
|
||||
|
||||
video = cv2.VideoCapture(video_path)
|
||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
index = 0
|
||||
frames = []
|
||||
while video.isOpened():
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
if index in indices:
|
||||
height, width, channel = frame.shape
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame[0:height, 0:width, 0:channel])
|
||||
if success:
|
||||
index += 1
|
||||
if index >= total_num_frames:
|
||||
break
|
||||
|
||||
video.release()
|
||||
metadata.frames_indices = indices
|
||||
return np.stack(frames), metadata
|
||||
|
||||
|
||||
def read_video_decord(
|
||||
video_path: str,
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import from decord
|
||||
requires_backends(read_video_decord, ["decord"])
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
|
||||
)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = vr.get_batch(indices).asnumpy()
|
||||
metadata.frames_indices = indices
|
||||
return frames, metadata
|
||||
|
||||
|
||||
def read_video_pyav(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with PyAV decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import av
|
||||
requires_backends(read_video_pyav, ["av"])
|
||||
import av
|
||||
|
||||
container = av.open(video_path)
|
||||
total_num_frames = container.streams.video[0].frames
|
||||
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = []
|
||||
container.seek(0)
|
||||
end_index = indices[-1]
|
||||
for i, frame in enumerate(container.decode(video=0)):
|
||||
if i > end_index:
|
||||
break
|
||||
if i >= 0 and i in indices:
|
||||
frames.append(frame)
|
||||
|
||||
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchvision(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with torchvision decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with fps is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
video, _, info = torchvision_io.read_video(
|
||||
video_path,
|
||||
start_pts=0.0,
|
||||
end_pts=None,
|
||||
pts_unit="sec",
|
||||
output_format="THWC",
|
||||
)
|
||||
video_fps = info["video_fps"]
|
||||
total_num_frames = video.size(0)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="torchvision",
|
||||
)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
video = video[indices].contiguous().numpy()
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
"decord": read_video_decord,
|
||||
"opencv": read_video_opencv,
|
||||
"pyav": read_video_pyav,
|
||||
"torchvision": read_video_torchvision,
|
||||
}
|
||||
|
||||
|
||||
def load_video(
|
||||
video: Union[str, "VideoInput"],
|
||||
num_frames: Optional[int] = None,
|
||||
fps: Optional[int] = None,
|
||||
backend: str = "opencv",
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> np.array:
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
video (`str` or `VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to video or local path.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
backend (`str`, *optional*, defaults to `"opencv"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
|
||||
a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
|
||||
If not provided, simple uniform sampling with `fps` is performed; otherwise `sample_indices_fn` takes priority over the other args.
|
||||
The function receives as input all args and kwargs passed to `load_video` and should output valid
|
||||
indices at which the video should be sampled. For example:
|
||||
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, Dict]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- Metadata dictionary.
|
||||
"""
|
||||
|
||||
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
|
||||
if fps is not None and num_frames is not None and sample_indices_fn is None:
|
||||
raise ValueError(
|
||||
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
|
||||
)
|
||||
|
||||
# If user didn't pass a sampling function, create one on the fly with default logic
|
||||
if sample_indices_fn is None:
|
||||
|
||||
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
|
||||
|
||||
sample_indices_fn = sample_indices_fn_func
|
||||
|
||||
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
# Lazy import from yt_dlp
|
||||
requires_backends(load_video, ["yt_dlp"])
|
||||
from yt_dlp import YoutubeDL
|
||||
|
||||
buffer = BytesIO()
|
||||
with redirect_stdout(buffer), YoutubeDL() as f:
|
||||
f.download([video])
|
||||
bytes_obj = buffer.getvalue()
|
||||
file_obj = BytesIO(bytes_obj)
|
||||
elif video.startswith("http://") or video.startswith("https://"):
|
||||
file_obj = BytesIO(requests.get(video).content)
|
||||
elif os.path.isfile(video):
|
||||
file_obj = video
|
||||
elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
|
||||
file_obj = None
|
||||
else:
|
||||
raise TypeError("Incorrect format used for video. Should be a URL linking to a video or a local path.")
|
||||
|
||||
# can also load with decord, but not cv2/torchvision
|
||||
# both will fail in case of url links
|
||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||
if video_is_url and backend in ["opencv", "torchvision"]:
|
||||
raise ValueError(
|
||||
"If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
|
||||
)
|
||||
|
||||
if file_obj is None:
|
||||
return video
|
||||
|
||||
if (
|
||||
(not is_decord_available() and backend == "decord")
|
||||
or (not is_av_available() and backend == "pyav")
|
||||
or (not is_cv2_available() and backend == "opencv")
|
||||
or (not is_torchvision_available() and backend == "torchvision")
|
||||
):
|
||||
raise ImportError(
|
||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||
f"Make sure to install {backend} before loading the video."
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
|
||||
return video, metadata
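As a concrete illustration of the custom-sampling hook, the snippet below passes its own `sample_indices_fn`. This is a sketch: the file path is a placeholder and the `pyav` backend requires `av` to be installed:

```python
import numpy as np

from transformers.video_utils import load_video

def sample_eight_frames(metadata, **kwargs):
    # pick 8 evenly spaced frames regardless of clip length
    return np.linspace(0, metadata.total_num_frames - 1, 8, dtype=int)

video, metadata = load_video("video.mp4", backend="pyav", sample_indices_fn=sample_eight_frames)
print(video.shape)  # (8, height, width, 3)
```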
|
||||
|
||||
|
||||
def load_images(
|
||||
images: Union[list, tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
|
||||
) -> Union["PIL.Image.Image", list["PIL.Image.Image"], list[list["PIL.Image.Image"]]]:
|
||||
|
@ -27,6 +27,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_tf_auto import *
|
||||
from .processing_auto import *
|
||||
from .tokenization_auto import *
|
||||
from .video_processing_auto import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -28,7 +28,14 @@ from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...image_processing_utils import ImageProcessingMixin
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
|
||||
from ...utils import (
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
PROCESSOR_NAME,
|
||||
VIDEO_PROCESSOR_NAME,
|
||||
cached_file,
|
||||
logging,
|
||||
)
|
||||
from ...video_processing_utils import BaseVideoProcessor
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
@ -295,7 +302,24 @@ class AutoProcessor:
|
||||
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
||||
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
||||
|
||||
# If not found, let's check whether the processor class is saved in a feature extractor config
|
||||
# Saved as video processor
|
||||
if preprocessor_config_file is None:
|
||||
preprocessor_config_file = cached_file(
|
||||
pretrained_model_name_or_path, VIDEO_PROCESSOR_NAME, **cached_file_kwargs
|
||||
)
|
||||
if preprocessor_config_file is not None:
|
||||
config_dict, _ = BaseVideoProcessor.get_video_processor_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
processor_class = config_dict.get("processor_class", None)
|
||||
if "AutoProcessor" in config_dict.get("auto_map", {}):
|
||||
processor_auto_map = config_dict["auto_map"]["AutoProcessor"]
|
||||
|
||||
# Saved as feature extractor
|
||||
if preprocessor_config_file is None:
|
||||
preprocessor_config_file = cached_file(
|
||||
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
|
||||
)
|
||||
if preprocessor_config_file is not None and processor_class is None:
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
|
384
src/transformers/models/auto/video_processing_auto.py
Normal file
@ -0,0 +1,384 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""AutoVideoProcessor class."""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
# Build the list of all video processors
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ...utils import (
|
||||
CONFIG_NAME,
|
||||
VIDEO_PROCESSOR_NAME,
|
||||
cached_file,
|
||||
is_torchvision_available,
|
||||
logging,
|
||||
)
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import BaseVideoProcessor
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
AutoConfig,
|
||||
model_type_to_module_name,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# This significantly improves completion suggestion performance when
|
||||
# the transformers package is used with Microsoft's Pylance language server.
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
|
||||
else:
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("instructblipvideo", "InstructBlipVideoVideoProcessor"),
|
||||
("llava_next_video", "LlavaNextVideoVideoProcessor"),
|
||||
("llava_onevision", "LlavaOnevisionVideoProcessor"),
|
||||
("qwen2_5_vl", "Qwen2_5_VLVideoProcessor"),
|
||||
("qwen2_vl", "Qwen2VLVideoProcessor"),
|
||||
("video_llava", "VideoLlavaVideoProcessor"),
|
||||
]
|
||||
)
|
||||
|
||||
for model_type, video_processors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
|
||||
fast_video_processor_class = video_processors
|
||||
|
||||
# If the torchvision is not available, we set it to None
|
||||
if not is_torchvision_available():
|
||||
fast_video_processor_class = None
|
||||
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES[model_type] = fast_video_processor_class
|
||||
|
||||
VIDEO_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, VIDEO_PROCESSOR_MAPPING_NAMES)
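The mapping above is what `AutoVideoProcessor` consults when resolving a class from a model config. A quick sketch of inspecting it (entries are set to `None` when torchvision is missing, as guarded above):

```python
from transformers.models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES

print(VIDEO_PROCESSOR_MAPPING_NAMES["llava_onevision"])  # "LlavaOnevisionVideoProcessor", or None without torchvision
```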
|
||||
|
||||
|
||||
def video_processor_class_from_name(class_name: str):
|
||||
for module_name, extractors in VIDEO_PROCESSOR_MAPPING_NAMES.items():
|
||||
if class_name in extractors:
|
||||
module_name = model_type_to_module_name(module_name)
|
||||
|
||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
for _, extractor in VIDEO_PROCESSOR_MAPPING._extra_content.items():
|
||||
if getattr(extractor, "__name__", None) == class_name:
|
||||
return extractor
|
||||
|
||||
# We did not find the class, but maybe it's because a dep is missing. In that case, the class will be in the main
|
||||
# init and we return the proper dummy to get an appropriate error message.
|
||||
main_module = importlib.import_module("transformers")
|
||||
if hasattr(main_module, class_name):
|
||||
return getattr(main_module, class_name)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_video_processor_config(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: Optional[bool] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads the video processor configuration from a pretrained model video processor configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
||||
huggingface.co.
|
||||
- a path to a *directory* containing a configuration file saved using the
|
||||
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
||||
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
||||
cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
||||
exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will only try to load the video processor configuration from local files.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`Dict`: The configuration of the video processor.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Download configuration from huggingface.co and cache.
|
||||
video_processor_config = get_video_processor_config("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
# This model does not have a video processor config so the result will be an empty dict.
|
||||
video_processor_config = get_video_processor_config("FacebookAI/xlm-roberta-base")
|
||||
|
||||
# Save a pretrained video processor locally and you can reload its config
|
||||
from transformers import AutoVideoProcessor
|
||||
|
||||
video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
video_processor.save_pretrained("video-processor-test")
|
||||
video_processor = get_video_processor_config("video-processor-test")
|
||||
```"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if token is not None:
|
||||
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
||||
token = use_auth_token
|
||||
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
VIDEO_PROCESSOR_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info(
|
||||
"Could not locate the video processor configuration file, will try to use the model config instead."
|
||||
)
|
||||
return {}
|
||||
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
return json.load(reader)
|
||||
|
||||
|
||||
@requires(backends=("vision", "torchvision"))
|
||||
class AutoVideoProcessor:
|
||||
r"""
|
||||
This is a generic video processor class that will be instantiated as one of the video processor classes of the
|
||||
library when created with the [`AutoVideoProcessor.from_pretrained`] class method.
|
||||
|
||||
This class cannot be instantiated directly using `__init__()` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoVideoProcessor is designed to be instantiated "
|
||||
"using the `AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(VIDEO_PROCESSOR_MAPPING_NAMES)
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
r"""
|
||||
Instantiate one of the video processor classes of the library from a pretrained model vocabulary.
|
||||
|
||||
The video processor class to instantiate is selected based on the `model_type` property of the config object
|
||||
(either passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's
|
||||
missing, by falling back to using pattern matching on `pretrained_model_name_or_path`:
|
||||
|
||||
List options
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained video_processor hosted inside a model repo on
|
||||
huggingface.co.
|
||||
- a path to a *directory* containing a video processor file saved using the
|
||||
[`~video_processing_utils.BaseVideoProcessor.save_pretrained`] method, e.g.,
|
||||
`./my_model_directory/`.
|
||||
- a path or url to a saved video processor JSON *file*, e.g.,
|
||||
`./my_model_directory/preprocessor_config.json`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model video processor should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the video processor files and override the cached versions if
|
||||
they exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
If `False`, then this function returns just the final video processor object. If `True`, then this
|
||||
function returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
||||
consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
|
||||
`kwargs` which has not been used to update `video_processor` and is otherwise ignored.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
|
||||
should only be set to `True` for repositories you trust and in which you have read the code, as it will
|
||||
execute code present on the Hub on your local machine.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are video processor attributes will be used to override the
|
||||
loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
|
||||
controlled by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoVideoProcessor
|
||||
|
||||
>>> # Download video processor from huggingface.co and cache.
|
||||
>>> video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
|
||||
>>> # If video processor files are in a directory (e.g. video processor was saved using *save_pretrained('./test/saved_model/')*)
|
||||
>>> # video_processor = AutoVideoProcessor.from_pretrained("./test/saved_model/")
|
||||
```"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if kwargs.get("token", None) is not None:
|
||||
raise ValueError(
|
||||
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
||||
)
|
||||
kwargs["token"] = use_auth_token
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
kwargs["_from_auto"] = True
|
||||
|
||||
config_dict, _ = BaseVideoProcessor.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
video_processor_class = config_dict.get("video_processor_type", None)
|
||||
video_processor_auto_map = None
|
||||
if "AutoVideoProcessor" in config_dict.get("auto_map", {}):
|
||||
video_processor_auto_map = config_dict["auto_map"]["AutoVideoProcessor"]
|
||||
|
||||
# If we still don't have the video processor class, check if we're loading from a previous feature extractor config
|
||||
# and if so, infer the video processor class from there.
|
||||
if video_processor_class is None and video_processor_auto_map is None:
|
||||
feature_extractor_class = config_dict.pop("feature_extractor_type", None)
|
||||
if feature_extractor_class is not None:
|
||||
video_processor_class = feature_extractor_class.replace("FeatureExtractor", "VideoProcessor")
|
||||
if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
|
||||
feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
|
||||
video_processor_auto_map = feature_extractor_auto_map.replace("FeatureExtractor", "VideoProcessor")
|
||||
|
||||
# If we don't find the video processor class in the video processor config, let's try the model config.
|
||||
if video_processor_class is None and video_processor_auto_map is None:
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
# It could be in `config.video_processor_type`
|
||||
video_processor_class = getattr(config, "video_processor_type", None)
|
||||
if hasattr(config, "auto_map") and "AutoVideoProcessor" in config.auto_map:
|
||||
video_processor_auto_map = config.auto_map["AutoVideoProcessor"]
|
||||
|
||||
if video_processor_class is not None:
|
||||
video_processor_class = video_processor_class_from_name(video_processor_class)
|
||||
|
||||
has_remote_code = video_processor_auto_map is not None
|
||||
has_local_code = video_processor_class is not None or type(config) in VIDEO_PROCESSOR_MAPPING
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
|
||||
)
|
||||
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = video_processor_auto_map
|
||||
video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
|
||||
_ = kwargs.pop("code_revision", None)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
video_processor_class.register_for_auto_class()
|
||||
return video_processor_class.from_dict(config_dict, **kwargs)
|
||||
elif video_processor_class is not None:
|
||||
return video_processor_class.from_dict(config_dict, **kwargs)
|
||||
# Last try: we use the VIDEO_PROCESSOR_MAPPING.
|
||||
elif type(config) in VIDEO_PROCESSOR_MAPPING:
|
||||
video_processor_class = VIDEO_PROCESSOR_MAPPING[type(config)]
|
||||
|
||||
if video_processor_class is not None:
|
||||
return video_processor_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"This video processor cannot be instantiated. Please make sure you have `torchvision` installed."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized video processor in {pretrained_model_name_or_path}. Should have a "
|
||||
f"`video_processor_type` key in its {VIDEO_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
|
||||
f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in VIDEO_PROCESSOR_MAPPING_NAMES.keys())}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register(
|
||||
config_class,
|
||||
video_processor_class,
|
||||
exist_ok=False,
|
||||
):
|
||||
"""
|
||||
Register a new video processor for this class.
|
||||
|
||||
Args:
|
||||
config_class ([`PretrainedConfig`]):
|
||||
The configuration corresponding to the model to register.
|
||||
video_processor_class ([`BaseVideoProcessor`]):
|
||||
The video processor to register.
|
||||
"""
|
||||
VIDEO_PROCESSOR_MAPPING.register(config_class, video_processor_class, exist_ok=exist_ok)
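A minimal, hypothetical sketch of registering a custom video processor with the auto class (both classes below are made up for illustration):

```python
from transformers import AutoVideoProcessor, PretrainedConfig
from transformers.video_processing_utils import BaseVideoProcessor

class MyVideoConfig(PretrainedConfig):
    model_type = "my-video-model"

class MyVideoProcessor(BaseVideoProcessor):
    pass

AutoVideoProcessor.register(MyVideoConfig, MyVideoProcessor)
```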
|
||||
|
||||
|
||||
__all__ = ["VIDEO_PROCESSOR_MAPPING", "AutoVideoProcessor"]
|
@ -27,7 +27,6 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
@ -166,7 +165,7 @@ class Emu3ImageProcessor(BaseImageProcessor):
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: Union[ImageInput, VideoInput],
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
|
@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
||||
from .image_processing_instructblipvideo import *
|
||||
from .modeling_instructblipvideo import *
|
||||
from .processing_instructblipvideo import *
|
||||
from .video_processing_instructblipvideo import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -29,20 +29,20 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_batched_videos,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
|
||||
from ...video_utils import VideoInput, make_batched_videos
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# TODO (raushan): processor can be removed after v5 release. Kept for backwards compatibility
|
||||
# Copied from transformers.models.blip.image_processing_blip.BlipImageProcessor with Blip->InstructBlipVideo, BLIP->InstructBLIPVideo
|
||||
class InstructBlipVideoImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
@ -236,6 +236,10 @@ class InstructBlipVideoImageProcessor(BaseImageProcessor):
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
||||
videos = make_batched_videos(images)
|
||||
logger.warning(
|
||||
"`InstructBlipVideoImageProcessor` is deprecated and will be removed in v5.0. "
|
||||
"We recommend to load an instance of `InstructBlipVideoVideoProcessor` to process videos for the model. "
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
|
@ -20,7 +20,6 @@ import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import VideoInput
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import (
|
||||
AddedToken,
|
||||
@ -31,6 +30,7 @@ from ...tokenization_utils_base import (
|
||||
TruncationStrategy,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
from ...video_utils import VideoInput
|
||||
from ..auto import AutoTokenizer
|
||||
|
||||
|
||||
@ -46,8 +46,8 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
docstring of [`~InstructBlipVideoProcessor.__call__`] and [`~InstructBlipVideoProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor (`InstructBlipVideoImageProcessor`):
|
||||
An instance of [`InstructBlipVideoImageProcessor`]. The image processor is a required input.
|
||||
video_processor (`InstructBlipVideoVideoProcessor`):
|
||||
An instance of [`InstructBlipVideoVideoProcessor`]. The video processor is a required input.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
|
||||
qformer_tokenizer (`AutoTokenizer`):
|
||||
@ -56,20 +56,20 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
Number of tokens used by the Qformer as queries, should be same as in model's config.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer", "qformer_tokenizer"]
|
||||
attributes = ["video_processor", "tokenizer", "qformer_tokenizer"]
|
||||
valid_kwargs = ["num_query_tokens"]
|
||||
image_processor_class = "InstructBlipVideoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
qformer_tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
|
||||
def __init__(self, video_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
|
||||
if not hasattr(tokenizer, "video_token"):
|
||||
self.video_token = AddedToken("<video>", normalized=False, special=True)
|
||||
tokenizer.add_tokens([self.video_token], special_tokens=True)
|
||||
else:
|
||||
self.video_token = tokenizer.video_token
|
||||
self.num_query_tokens = num_query_tokens
|
||||
super().__init__(image_processor, tokenizer, qformer_tokenizer)
|
||||
super().__init__(video_processor, tokenizer, qformer_tokenizer)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -176,7 +176,7 @@ class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
||||
|
||||
if images is not None:
|
||||
image_encoding = self.image_processor(images, return_tensors=return_tensors)
|
||||
image_encoding = self.video_processor(images, return_tensors=return_tensors)
|
||||
encoding.update(image_encoding)
|
||||
|
||||
return encoding
|
||||
|
@ -0,0 +1,125 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Video processor class for InstructBLIPVideo
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import BaseVideoProcessor
|
||||
from ...video_utils import group_videos_by_shape, reorder_videos
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class InstructBlipVideoVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 384, "width": 384}
|
||||
default_to_square = True
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = InstructBlipVideoVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[InstructBlipVideoVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: List["torch.Tensor"],
|
||||
do_convert_rgb: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
size_divisor: Optional[int],
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
do_pad: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
if do_convert_rgb:
|
||||
stacked_videos = self.convert_to_rgb(stacked_videos)
|
||||
if do_resize:
|
||||
stacked_videos = self.resize(
|
||||
stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation
|
||||
)
|
||||
resized_videos_grouped[shape] = stacked_videos
|
||||
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
|
||||
|
||||
# Group videos by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns videos with different sizes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
|
||||
processed_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
if do_center_crop:
|
||||
stacked_videos = self.center_crop(stacked_videos, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_videos = self.rescale_and_normalize(
|
||||
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_videos_grouped[shape] = stacked_videos
|
||||
|
||||
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
|
||||
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
|
||||
|
||||
return BatchFeature(data={"pixel_values": processed_videos}, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["InstructBlipVideoVideoProcessor"]
@ -29,13 +29,10 @@ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import (
|
||||
ImageInput,
|
||||
VideoInput,
|
||||
VideoMetadata,
|
||||
concatenate_list,
|
||||
load_video,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
)
|
||||
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
|
||||
|
||||
|
||||
class InternVLImagesKwargs(ImagesKwargs, total=False):
|
||||
@ -53,9 +50,7 @@ class InternVLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
"images_kwargs": {
|
||||
"crop_to_patches": True,
|
||||
},
|
||||
"videos_kwargs": {
|
||||
"crop_to_patches": False,
|
||||
},
|
||||
"videos_kwargs": {},
|
||||
}
|
||||
|
||||
|
||||
@ -69,6 +64,8 @@ class InternVLProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
video_processor ([`AutoVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
image_seq_length (`int`, *optional*, defaults to 256):
|
||||
The number of image tokens to use per image patch. It should be set so that:
|
||||
image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
|
||||
@ -76,18 +73,20 @@ class InternVLProcessor(ProcessorMixin):
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
attributes = ["image_processor", "tokenizer", "video_processor"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"image_seq_length",
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
video_processor=None,
|
||||
image_seq_length: int = 256,
|
||||
chat_template=None,
|
||||
**kwargs,
|
||||
@ -99,7 +98,7 @@ class InternVLProcessor(ProcessorMixin):
|
||||
self.video_token = tokenizer.video_token
|
||||
self.image_token_id = tokenizer.context_image_token_id
|
||||
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
|
||||
|
||||
def _insert_media_placeholders(
|
||||
self,
|
||||
@ -237,10 +236,9 @@ class InternVLProcessor(ProcessorMixin):
|
||||
videos = make_batched_videos(videos)
|
||||
num_frames_per_video = [len(video) for video in videos]
|
||||
video_patch_indices = np.cumsum(num_frames_per_video)
|
||||
output_kwargs["images_kwargs"]["crop_to_patches"] = False
|
||||
video_inputs = self.image_processor(images=videos, **output_kwargs["videos_kwargs"])
|
||||
video_num_patches = video_inputs.pop("num_patches")
|
||||
video_pixel_values = video_inputs.pop("pixel_values")
|
||||
video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_num_patches = [1 for frames in num_frames_per_video for _ in range(frames)]
|
||||
video_pixel_values = video_inputs.pop("pixel_values_videos").flatten(0, 1)
|
||||
video_num_patches_indices = np.cumsum(video_num_patches)
|
||||
|
||||
if images is not None or videos is not None:
|
||||
|
@ -0,0 +1,55 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fast Video processor class for InternVL."""
|
||||
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import (
|
||||
is_vision_available,
|
||||
)
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
class InternVLVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class InternVLVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 384, "width": 384}
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = InternVLVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[InternVLVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["InternVLVideoProcessor"]
@ -31,15 +31,14 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_batched_videos,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
from ...video_utils import VideoInput, make_batched_videos
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -358,6 +357,10 @@ class LlavaNextVideoImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
images = make_batched_videos(images)
|
||||
logger.warning(
|
||||
"`LlavaNextVideoImageProcessor` is deprecated and will be removed in v5.0. "
|
||||
"We recommend to load an instance of `LlavaNextVideoVideoProcessor` to process videos for the model. "
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
|
@ -22,10 +22,11 @@ import numpy as np
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -52,7 +53,7 @@ class LlavaNextVideoProcessor(ProcessorMixin):
|
||||
[`LlamaTokenizerFast`]. See the [`~LlavaNextVideoProcessor.__call__`] and [`~LlavaNextVideoProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
video_processor ([`LlavaNextVideoImageProcessor`], *optional*):
|
||||
video_processor ([`LlavaNextVideoVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
image_processor ([`LlavaNextImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
@ -86,7 +87,7 @@ class LlavaNextVideoProcessor(ProcessorMixin):
|
||||
"num_additional_image_tokens",
|
||||
]
|
||||
image_processor_class = ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")
|
||||
video_processor_class = "LlavaNextVideoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
|
||||
def __init__(
|
||||
|
@ -0,0 +1,56 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Video processor class for LLaVa-NeXT-Video."""
|
||||
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import is_vision_available
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
class LlavaNextVideoFastVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class LlavaNextVideoVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"shortest_edge": 224}
|
||||
default_to_square = False
|
||||
crop_size = {"height": 224, "width": 224}
|
||||
do_resize = True
|
||||
do_center_crop = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = LlavaNextVideoFastVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[LlavaNextVideoFastVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["LlavaNextVideoVideoProcessor"]
@ -17,18 +17,17 @@ Processor class for LLaVa-Onevision.
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_processing_utils import select_best_resolution
|
||||
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
from ..auto import AutoImageProcessor
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -85,7 +84,7 @@ class LlavaOnevisionProcessor(ProcessorMixin):
|
||||
]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
video_processor_class = "LlavaOnevisionVideoProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -300,49 +299,5 @@ class LlavaOnevisionProcessor(ProcessorMixin):
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
video_processor_path = os.path.join(save_directory, "video_processor")
|
||||
self.video_processor.save_pretrained(video_processor_path)
|
||||
|
||||
video_processor_present = "video_processor" in self.attributes
|
||||
try:
|
||||
if video_processor_present:
|
||||
self.attributes.remove("video_processor")
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
finally:
|
||||
if video_processor_present:
|
||||
self.attributes += ["video_processor"]
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
|
||||
try:
|
||||
video_processor = AutoImageProcessor.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="video_processor"
|
||||
)
|
||||
processor.video_processor = video_processor
|
||||
except EnvironmentError:
|
||||
# this means users are using prev version of saved processor where we had only one preprocessor_config.json
|
||||
# for loading back that should work and load a LlavaOnevisionVideoProcessor class
|
||||
logger.info(
|
||||
"You are loading `LlavaOnevisionProcessor` but the indicated `path` doesn't contain a folder called "
|
||||
"`video_processor`. It is strongly recommended to load and save the processor again so the video processor is saved "
|
||||
"in a separate config."
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
__all__ = ["LlavaOnevisionProcessor"]
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -14,309 +14,44 @@
|
||||
# limitations under the License.
|
||||
"""Video processor class for LLaVa-Onevision."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
convert_to_rgb,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_batched_videos,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import is_vision_available
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
@requires(backends=("vision",))
|
||||
class LlavaOnevisionVideoProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a LLaVa-Onevision-Video video processor. Based on [`SiglipImageProcessor`], extended to process each video frame.
|
||||
class LlavaOnevisionFastVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class LlavaOnevisionVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 384, "width": 384}
|
||||
rescale_factor = 1 / 255
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = LlavaOnevisionFastVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
def __init__(self, **kwargs: Unpack[LlavaOnevisionFastVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 384, "width": 384}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> list[np.ndarray]:
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Batch of frames (one video) to preprocess. Expects a batch of frames with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if do_rescale and is_scaled_image(images[0]):
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled videos. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
videos = make_batched_videos(videos)
|
||||
|
||||
if not valid_images(videos[0]):
|
||||
raise ValueError(
|
||||
"Invalid video type. Must be a list consisting of PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
size_tuple = (
|
||||
(size["height"], size["width"])
|
||||
if "height" in size and "width" in size
|
||||
else (size["shortest_edge"], size["shortest_edge"])
|
||||
)
|
||||
|
||||
pixel_values = [
|
||||
self._preprocess(
|
||||
video,
|
||||
do_resize=do_resize,
|
||||
size=size_tuple,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for video in videos
|
||||
]
|
||||
|
||||
return BatchFeature(
|
||||
data={"pixel_values_videos": pixel_values},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["LlavaOnevisionVideoProcessor"]
@ -24,9 +24,10 @@ from typing import List, Optional, Union
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, VideoInput, make_batched_videos
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
||||
from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
|
||||
from ...video_utils import VideoInput, make_batched_videos
|
||||
|
||||
|
||||
class Qwen2_5_OmniVideosKwargs(VideosKwargs):
|
||||
@ -81,6 +82,8 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor.
|
||||
video_processor ([`Qwen2VLVideoProcessor`], *optional*):
|
||||
The video processor.
|
||||
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
|
||||
The audio feature extractor.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
@ -89,14 +92,17 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "feature_extractor", "tokenizer"]
|
||||
attributes = ["image_processor", "video_processor", "feature_extractor", "tokenizer"]
|
||||
image_processor_class = "Qwen2VLImageProcessor"
|
||||
video_processor_class = "Qwen2VLVideoProcessor"
|
||||
feature_extractor_class = "WhisperFeatureExtractor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
valid_kwargs = ["chat_template"]
|
||||
|
||||
def __init__(self, image_processor=None, feature_extractor=None, tokenizer=None, chat_template=None):
|
||||
super().__init__(image_processor, feature_extractor, tokenizer, chat_template=chat_template)
|
||||
def __init__(
|
||||
self, image_processor=None, video_processor=None, feature_extractor=None, tokenizer=None, chat_template=None
|
||||
):
|
||||
super().__init__(image_processor, video_processor, feature_extractor, tokenizer, chat_template=chat_template)
|
||||
self.image_token = self.tokenizer.image_token
|
||||
self.audio_token = self.tokenizer.audio_token
|
||||
self.video_token = self.tokenizer.video_token
|
||||
@ -175,10 +181,10 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
|
||||
|
||||
if videos is not None:
|
||||
videos = make_batched_videos(videos)
|
||||
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
|
||||
videos_inputs = self.video_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
|
||||
fps = [fps] * len(videos)
|
||||
videos_inputs["video_second_per_grid"] = [
|
||||
self.image_processor.temporal_patch_size / fps[i] for i in range(len(fps))
|
||||
self.video_processor.temporal_patch_size / fps[i] for i in range(len(fps))
|
||||
]
|
||||
video_grid_thw = iter(videos_inputs["video_grid_thw"])
|
||||
video_second_per_grid = iter(videos_inputs["video_second_per_grid"])
|
||||
@ -220,7 +226,8 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
|
||||
seconds_per_chunk,
|
||||
):
|
||||
# Extend mm token length
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
merge_length_image = self.image_processor.merge_size**2
|
||||
merge_length_video = self.video_processor.merge_size**2
|
||||
|
||||
processed_text = []
|
||||
for sample in text:
|
||||
@ -234,17 +241,17 @@ class Qwen2_5OmniProcessor(ProcessorMixin):
|
||||
if special_token == self.audio_token:
|
||||
sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
|
||||
elif special_token == self.image_token:
|
||||
image_seq_length = next(image_grid_thw).prod() // merge_length
|
||||
image_seq_length = next(image_grid_thw).prod() // merge_length_image
|
||||
sample = sample.replace(self.image_token, "<|image_placeholder|>" * image_seq_length, 1)
|
||||
elif special_token == self.video_token:
|
||||
if not use_audio_in_video:
|
||||
video_seq_length = next(video_grid_thw).prod() // merge_length
|
||||
video_seq_length = next(video_grid_thw).prod() // merge_length_video
|
||||
sample = sample.replace(self.video_token, "<|video_placeholder|>" * video_seq_length, 1)
|
||||
else:
|
||||
audio_token_indices = np.arange(next(audio_lengths))
|
||||
curr_video_grid_thw = next(video_grid_thw)
|
||||
height = curr_video_grid_thw[1] // self.image_processor.merge_size
|
||||
width = curr_video_grid_thw[2] // self.image_processor.merge_size
|
||||
height = curr_video_grid_thw[1] // self.video_processor.merge_size
|
||||
width = curr_video_grid_thw[2] // self.video_processor.merge_size
|
||||
video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
|
||||
video_token_indices = np.broadcast_to(
|
||||
video_token_indices, (video_token_indices.shape[0], height, width)
|
||||
|
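The placeholder expansion above now reads `merge_size` and `temporal_patch_size` from the image and video processors separately; a small arithmetic sketch for one video grid, assuming the usual merge_size=2 and temporal_patch_size=2:

import numpy as np

merge_size, temporal_patch_size, fps = 2, 2, 2.0
video_grid_thw = np.array([4, 36, 64])          # (temporal, height, width) patch grid of one video

merge_length_video = merge_size**2              # 4 spatial patches merge into one placeholder token
video_seq_length = int(video_grid_thw.prod()) // merge_length_video
video_second_per_grid = temporal_patch_size / fps

print(video_seq_length, video_second_per_grid)  # 2304 placeholder tokens, 1.0 second per temporal grid step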
@ -46,11 +46,12 @@ from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs
|
||||
from ...activations import ACT2FN
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...image_utils import ImageInput
|
||||
from ...modeling_flash_attention_utils import is_flash_attn_available
|
||||
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
@ -928,6 +929,8 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
@ -990,37 +993,32 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
image_inputs = videos_inputs = {}
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
|
||||
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
|
||||
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
|
||||
if isinstance(fps, (int, float)):
|
||||
second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
|
||||
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
|
||||
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
|
||||
second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
|
||||
second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
|
||||
)
|
||||
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
|
||||
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = text.copy() # below lines change text in-place
|
||||
if image_grid_thw is not None:
|
||||
if images is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
@ -1030,8 +1028,8 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
if video_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
if videos is not None:
|
||||
merge_length = self.video_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.video_token in text[i]:
|
||||
|
@ -26,9 +26,10 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
|
||||
@ -64,17 +65,20 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
video_processor ([`Qwen2_5_VLVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
attributes = ["image_processor", "tokenizer", "video_processor"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
self.image_token_id = (
|
||||
@ -87,7 +91,7 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -138,37 +142,32 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
image_inputs = videos_inputs = {}
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"])
|
||||
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
|
||||
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
|
||||
if isinstance(fps, (int, float)):
|
||||
second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw)
|
||||
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
|
||||
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
|
||||
second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps]
|
||||
second_per_grid_ts = [self.video_processor.temporal_patch_size / tmp for tmp in fps]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
|
||||
)
|
||||
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
|
||||
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = text.copy() # below lines change text in-place
|
||||
if image_grid_thw is not None:
|
||||
if images is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
@ -178,8 +177,8 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
if video_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
if videos is not None:
|
||||
merge_length = self.video_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.video_token in text[i]:
|
||||
|
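The `second_per_grid_ts` branch above accepts either one fps value shared by all clips or one value per clip; a sketch of both cases, again assuming temporal_patch_size=2:

temporal_patch_size = 2
video_grid_thw = [(4, 36, 64), (8, 24, 24)]   # grids for two videos

fps = 2.0                                     # scalar: same sampling rate for every clip
second_per_grid_ts = [temporal_patch_size / fps] * len(video_grid_thw)   # [1.0, 1.0]

fps = [2.0, 4.0]                              # one value per clip
second_per_grid_ts = [temporal_patch_size / f for f in fps]              # [1.0, 0.5]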
@ -36,11 +36,9 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
@ -48,6 +46,7 @@ from ...image_utils import (
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, logging
|
||||
from ...video_utils import VideoInput, make_batched_videos
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -407,8 +406,6 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
|
||||
if images is not None:
|
||||
images = make_flat_list_of_images(images)
|
||||
if videos is not None:
|
||||
videos = make_batched_videos(videos)
|
||||
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError(
|
||||
@ -426,6 +423,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
data = {}
|
||||
if images is not None:
|
||||
pixel_values, vision_grid_thws = [], []
|
||||
for image in images:
|
||||
@ -450,10 +448,17 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
vision_grid_thws.append(image_grid_thw)
|
||||
pixel_values = np.array(pixel_values)
|
||||
vision_grid_thws = np.array(vision_grid_thws)
|
||||
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
||||
data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
|
||||
|
||||
# kept for BC only and should be removed after v5.0
|
||||
if videos is not None:
|
||||
pixel_values, vision_grid_thws = [], []
|
||||
logger.warning(
|
||||
"`Qwen2VLImageProcessor` works only with image inputs and doesn't process videos anymore. "
|
||||
"This is a deprecated behavior and will be removed in v5.0. "
|
||||
"Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
|
||||
)
|
||||
videos = make_batched_videos(videos)
|
||||
pixel_values_videos, vision_grid_thws_videos = [], []
|
||||
for images in videos:
|
||||
patches, video_grid_thw = self._preprocess(
|
||||
images,
|
||||
@ -472,11 +477,14 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
pixel_values.extend(patches)
|
||||
vision_grid_thws.append(video_grid_thw)
|
||||
pixel_values = np.array(pixel_values)
|
||||
vision_grid_thws = np.array(vision_grid_thws)
|
||||
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
||||
pixel_values_videos.extend(patches)
|
||||
vision_grid_thws_videos.append(video_grid_thw)
|
||||
data.update(
|
||||
{
|
||||
"pixel_values_videos": np.array(pixel_values_videos),
|
||||
"video_grid_thw": np.array(vision_grid_thws_videos),
|
||||
}
|
||||
)
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
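Given the deprecation warning added above, the non-deprecated path keeps images on the image processor and sends clips to the video processor; a sketch with random inputs, where the default instantiation and the `images=`/`videos=` call signatures are assumptions rather than part of this hunk:

import numpy as np
from transformers import Qwen2VLImageProcessor, Qwen2VLVideoProcessor  # the latter requires torchvision

image_processor = Qwen2VLImageProcessor()
video_processor = Qwen2VLVideoProcessor()

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
clip = [np.random.randint(0, 256, (8, 480, 640, 3), dtype=np.uint8)]

image_out = image_processor(images=image, return_tensors="pt")   # pixel_values + image_grid_thw
video_out = video_processor(videos=clip, return_tensors="pt")    # pixel_values_videos + video_grid_thw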
@ -35,9 +35,7 @@ from ...image_utils import (
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
VideoInput,
|
||||
get_image_size,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
valid_images,
|
||||
)
|
||||
@ -50,6 +48,7 @@ from ...utils import (
|
||||
is_torchvision_v2_available,
|
||||
logging,
|
||||
)
|
||||
from ...video_utils import VideoInput, make_batched_videos
|
||||
from .image_processing_qwen2_vl import smart_resize
|
||||
|
||||
|
||||
@ -334,8 +333,6 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
if images is not None:
|
||||
images = make_flat_list_of_images(images)
|
||||
if videos is not None:
|
||||
videos = make_batched_videos(videos)
|
||||
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError(
|
||||
@ -343,6 +340,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
data = {}
|
||||
if images is not None:
|
||||
pixel_values, vision_grid_thws = [], []
|
||||
for image in images:
|
||||
@ -367,10 +365,17 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
|
||||
vision_grid_thws.append(image_grid_thw)
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
vision_grid_thws = torch.tensor(vision_grid_thws)
|
||||
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
||||
data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
|
||||
|
||||
# kept for BC only and should be removed after v5.0
|
||||
if videos is not None:
|
||||
pixel_values, vision_grid_thws = [], []
|
||||
logger.warning(
|
||||
"`Qwen2VLImageProcessorFast` works only with image inputs and doesn't process videos anymore. "
|
||||
"This is a deprecated behavior and will be removed in v5.0. "
|
||||
"Your videos should be forwarded to `Qwen2VLVideoProcessor`. "
|
||||
)
|
||||
videos = make_batched_videos(videos)
|
||||
pixel_values_videos, vision_grid_thws_videos = [], []
|
||||
for images in videos:
|
||||
patches, video_grid_thw = self._preprocess(
|
||||
images,
|
||||
@ -389,11 +394,11 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
pixel_values.extend(patches)
|
||||
vision_grid_thws.append(video_grid_thw)
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
vision_grid_thws = torch.tensor(vision_grid_thws)
|
||||
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
||||
pixel_values_videos.extend(patches)
|
||||
vision_grid_thws_videos.append(video_grid_thw)
|
||||
pixel_values_videos = torch.stack(pixel_values_videos)
|
||||
vision_grid_thws_videos = torch.tensor(vision_grid_thws_videos)
|
||||
data.update({"pixel_values_videos": pixel_values_videos, "video_grid_thw": vision_grid_thws_videos})
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
@ -24,10 +24,11 @@ Processor class for Qwen2-VL.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -60,16 +61,19 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
video_processor ([`Qwen2VLVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
attributes = ["image_processor", "tokenizer", "video_processor"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
self.image_token_id = (
|
||||
@ -82,7 +86,7 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -132,26 +136,22 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
image_inputs = videos_inputs = {}
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"])
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["videos_kwargs"])
|
||||
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text = text.copy() # below lines change text in-place
|
||||
|
||||
if image_grid_thw is not None:
|
||||
if images is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
@ -161,8 +161,8 @@ class Qwen2VLProcessor(ProcessorMixin):
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
if video_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
if videos is not None:
|
||||
merge_length = self.video_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.video_token in text[i]:
|
||||
|
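With `video_processor` registered as a third attribute, the processor now routes images and videos to separate components. A minimal sketch of the end-to-end call (example checkpoint id, dummy video, illustrative prompt format):

import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # example checkpoint
video = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # 8 dummy RGB frames

inputs = processor(
    text="<|vision_start|><|video_pad|><|vision_end|>Describe the video.",
    videos=video,
    return_tensors="pt",
)
print(sorted(inputs.keys()))  # expect input_ids, attention_mask, pixel_values_videos, video_grid_thw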
208
src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py
Normal file
@ -0,0 +1,208 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""video processor class for Qwen2-VL."""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...image_processing_utils import (
|
||||
BatchFeature,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
from ...video_utils import group_videos_by_shape, reorder_videos
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
from .image_processing_qwen2_vl import smart_resize
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class Qwen2VLVideoProcessorInitKwargs(VideosKwargs):
|
||||
min_pixels: Optional[int]
|
||||
max_pixels: Optional[int]
|
||||
patch_size: Optional[int]
|
||||
temporal_patch_size: Optional[int]
|
||||
merge_size: Optional[int]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast Qwen2-VL image processor that dynamically resizes videos based on the original videos.",
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
||||
"""
|
||||
min_pixels (`int`, *optional*, defaults to `56 * 56`):
The minimum number of pixels the video may be resized to.
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
The maximum number of pixels the video may be resized to.
patch_size (`int`, *optional*, defaults to 14):
The spatial patch size of the vision encoder.
temporal_patch_size (`int`, *optional*, defaults to 2):
The temporal patch size of the vision encoder.
merge_size (`int`, *optional*, defaults to 2):
The merge size of the vision encoder to LLM encoder.
|
||||
""",
|
||||
)
|
||||
@requires(backends=("torchvision",))
|
||||
class Qwen2VLVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
min_pixels = 56 * 56
|
||||
max_pixels = 28 * 28 * 1280
|
||||
patch_size = 14
|
||||
temporal_patch_size = 2
|
||||
merge_size = 2
|
||||
valid_kwargs = Qwen2VLVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values_videos", "video_grid_thw"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Qwen2VLVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
self.size = {"shortest_edge": self.min_pixels, "longest_edge": self.max_pixels}
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: List["torch.Tensor"],
|
||||
do_convert_rgb: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
min_pixels: Optional[int] = None,
|
||||
max_pixels: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
temporal_patch_size: Optional[int] = None,
|
||||
merge_size: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
height, width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
|
||||
resized_height, resized_width = height, width
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
stacked_videos = F.resize(
|
||||
stacked_videos, size=(resized_height, resized_width), interpolation=interpolation
|
||||
)
|
||||
resized_videos_grouped[shape] = stacked_videos
|
||||
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
|
||||
|
||||
# Group videos by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns videos with different sizes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
|
||||
processed_videos_grouped = {}
|
||||
processed_grids = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
|
||||
|
||||
# Fused rescale and normalize
|
||||
stacked_videos = self.rescale_and_normalize(
|
||||
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
patches = stacked_videos
|
||||
|
||||
# Check that videos have `num_frames` divisible by `temporal_patch_size`
|
||||
if patches.shape[1] % temporal_patch_size != 0:
|
||||
repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
|
||||
patches = torch.cat([patches, repeats], dim=1)
|
||||
|
||||
batch_size, grid_t, channel = patches.shape[:3]
|
||||
grid_t = grid_t // temporal_patch_size
|
||||
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||
|
||||
patches = patches.view(
|
||||
batch_size,
|
||||
grid_t,
|
||||
temporal_patch_size,
|
||||
channel,
|
||||
grid_h // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
grid_w // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
)
|
||||
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
|
||||
flatten_patches = patches.reshape(
|
||||
batch_size,
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size,
|
||||
)
|
||||
|
||||
processed_videos_grouped[shape] = flatten_patches
|
||||
processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
|
||||
|
||||
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
|
||||
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
|
||||
pixel_values_videos = torch.cat(processed_videos, dim=0)
|
||||
video_grid_thw = torch.tensor(processed_grids)
|
||||
|
||||
return BatchFeature(
|
||||
data={"pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Qwen2VLVideoProcessor"]
|
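For intuition about the shapes produced by `_preprocess` above, here is a back-of-the-envelope sketch assuming 8 RGB frames at 224x224 and the default patch_size=14, temporal_patch_size=2, merge_size=2 (no resizing is triggered because 224 is already a multiple of 28 and the pixel count is within bounds):

num_frames, channels, height, width = 8, 3, 224, 224
patch_size, temporal_patch_size, merge_size = 14, 2, 2

grid_t = num_frames // temporal_patch_size                   # 4
grid_h, grid_w = height // patch_size, width // patch_size   # 16, 16

pixel_values_videos_shape = (
    grid_t * grid_h * grid_w,                                 # 1024 flattened patches
    channels * temporal_patch_size * patch_size**2,           # 1176 values per patch
)
num_video_pad_tokens = (grid_t * grid_h * grid_w) // merge_size**2  # 256 <|video_pad|> tokens
print(pixel_values_videos_shape, num_video_pad_tokens)        # (1024, 1176) 256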
@ -21,10 +21,11 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
|
||||
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
|
||||
from ...utils import is_tf_available, is_torch_available
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
@ -21,10 +21,11 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
|
||||
from ...utils import is_torch_available
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
@ -23,23 +23,20 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import (
|
||||
ImageInput,
|
||||
VideoInput,
|
||||
load_video,
|
||||
make_batched_videos,
|
||||
make_nested_list_of_images,
|
||||
)
|
||||
from ...image_utils import ImageInput, make_nested_list_of_images
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import BatchEncoding, TextInput
|
||||
from ...utils import is_num2words_available, logging
|
||||
from .video_processing_smolvlm import (
|
||||
from ...utils import is_num2words_available, is_vision_available, logging
|
||||
from ...video_utils import VideoInput, load_video, make_batched_videos
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from .video_processing_smolvlm import (
|
||||
DEFAULT_MEDIA_OUTTRO,
|
||||
DEFAULT_VIDEO_INTRO,
|
||||
FRAME_TIMESTAMP_MESSAGE,
|
||||
smolvlm_sample_indices_fn,
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...tokenization_utils_base import PreTokenizedInput
|
||||
@ -129,8 +126,10 @@ class SmolVLMProcessor(ProcessorMixin):
|
||||
Args:
|
||||
image_processor (`SmolVLMImageProcessor`):
|
||||
An instance of [`SmolVLMImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`PreTrainedTokenizerBase`, *optional*):
|
||||
tokenizer (`PreTrainedTokenizerBase`):
|
||||
An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
|
||||
video_processor (`SmolVLMImageProcessor`):
|
||||
An instance of [`SmolVLMImageProcessor`]. The video processor is a required input.
|
||||
image_seq_len (`int`, *optional*, defaults to 169):
|
||||
The length of the image sequence i.e. the number of <image> tokens per image in the input.
|
||||
This parameter is used to build the string from the input prompt and image tokens and should match the
|
||||
@ -139,13 +138,22 @@ class SmolVLMProcessor(ProcessorMixin):
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
attributes = ["image_processor", "tokenizer", "video_processor"]
|
||||
valid_kwargs = ["image_seq_len", "chat_template"]
|
||||
image_processor_class = "SmolVLMImageProcessor"
|
||||
video_processor_class = (
|
||||
"SmolVLMImageProcessor" # TODO: raushan should be VideoProcessor when LANCZOS resizing is settled
|
||||
)
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self, image_processor, tokenizer=None, image_seq_len: int = 169, chat_template: Optional[str] = None, **kwargs
|
||||
self,
|
||||
image_processor,
|
||||
tokenizer,
|
||||
video_processor,
|
||||
image_seq_len: int = 169,
|
||||
chat_template: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.fake_image_token = getattr(tokenizer, "fake_image_token", "<fake_token_around_image>")
|
||||
self.image_token = getattr(tokenizer, "image_token", "<image>")
|
||||
@ -154,14 +162,14 @@ class SmolVLMProcessor(ProcessorMixin):
|
||||
self.global_image_token = getattr(tokenizer, "global_image_token", "<global-img>")
|
||||
self.image_seq_len = image_seq_len
|
||||
|
||||
self.video_size = image_processor.video_sampling["video_size"]
|
||||
self.video_size = video_processor.video_sampling["video_size"]
|
||||
self.image_size = image_processor.size
|
||||
|
||||
self.do_image_splitting = image_processor.do_image_splitting
|
||||
self.do_video_splitting = image_processor.video_sampling.get("do_image_splitting", False)
|
||||
self.do_video_splitting = video_processor.video_sampling.get("do_image_splitting", False)
|
||||
|
||||
self.default_max_frames = image_processor.video_sampling["max_frames"]
|
||||
self.default_fps = image_processor.video_sampling["fps"]
|
||||
self.default_max_frames = video_processor.video_sampling["max_frames"]
|
||||
self.default_fps = video_processor.video_sampling["fps"]
|
||||
# Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits, optionally surrounded by newline characters
|
||||
# self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
|
||||
|
||||
@ -170,7 +178,7 @@ class SmolVLMProcessor(ProcessorMixin):
|
||||
"Package `num2words` is required to run SmolVLM processor. Install it with `pip install num2words`."
|
||||
)
|
||||
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
|
||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
|
||||
|
||||
def process_vision(self, text, images, output_kwargs, do_image_splitting=False, image_processor_size=None):
|
||||
if text is not None:
|
||||
@ -266,6 +274,9 @@ class SmolVLMProcessor(ProcessorMixin):
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
Wherever an image token, `<image>` is encountered it is expanded to
|
||||
`<fake_token_around_image>` + `<row_x_col_y>` + `<image>` * `image_seq_len` + `<fake_token_around_image>`.
|
||||
videos (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
|
||||
The video or batch of videos to be prepared. Each video can be a list of PIL frames, NumPy array or PyTorch
|
||||
tensor. If is of type `List[VideoInput]`, it's assumed that this is for a single prompt i.e. of batch size 1.
|
||||
return_tensors (`Union[str, TensorType]`, *optional*):
|
||||
If set, will return tensors of a particular framework. See [`PreTrainedTokenizerFast.__call__`] for more
|
||||
information.
|
||||
|
@ -14,9 +14,46 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Make sure these are imported from your library
|
||||
from ...image_processing_utils import (
|
||||
BatchFeature,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
SizeDict,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
from ...video_utils import group_videos_by_shape, reorder_videos
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -28,6 +65,7 @@ DEFAULT_VIDEO_INTRO = (
|
||||
)
|
||||
DEFAULT_MEDIA_OUTTRO = "\n\n"
|
||||
FRAME_TIMESTAMP_MESSAGE = "\nFrame from {timestamp}:"
|
||||
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
|
||||
|
||||
|
||||
def smolvlm_sample_indices_fn(metadata, max_frames, target_fps, skip_secs=0):
|
||||
@ -88,3 +126,221 @@ def smolvlm_sample_indices_fn(metadata, max_frames, target_fps, skip_secs=0):
|
||||
indices = np.unique(indices)
|
||||
|
||||
return indices
|
||||
|
||||
|
||||
def get_max_height_width(videos: list["torch.Tensor"]) -> List[int]:
|
||||
"""
|
||||
Get the maximum height and width across all videos in a batch.
|
||||
"""
|
||||
max_height = max_width = float("-inf")
|
||||
for video in videos:
|
||||
height, width = video.size()[-2:]
|
||||
max_height = max(height, max_height)
|
||||
max_width = max(width, max_width)
|
||||
return (max_height, max_width)
|
||||
|
||||
|
||||
def get_resize_output_image_size(
|
||||
video,
|
||||
resolution_max_side: int,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Get the output size of the video after resizing, given the maximum allowed size for the longest edge.
|
||||
Args:
|
||||
video (`np.ndarray`):
|
||||
Video to resize.
|
||||
resolution_max_side (`int`):
|
||||
The longest edge of the video will be resized to this value. The shortest edge will be resized to keep the
|
||||
input aspect ratio.
|
||||
Returns:
|
||||
The output size of the video after resizing.
|
||||
"""
|
||||
height, width = video.size()[-2:]
|
||||
|
||||
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
|
||||
# The output size must be below the MAX_IMAGE_SIZE
|
||||
resolution_max_side = min(MAX_IMAGE_SIZE, resolution_max_side)
|
||||
resolution_max_side = max(height, width) if resolution_max_side is None else resolution_max_side
|
||||
aspect_ratio = width / height
|
||||
|
||||
if width >= height:
|
||||
width = resolution_max_side
|
||||
height = int(width / aspect_ratio)
|
||||
if height % 2 != 0:
|
||||
height += 1
|
||||
elif height > width:
|
||||
height = resolution_max_side
|
||||
width = int(height * aspect_ratio)
|
||||
if width % 2 != 0:
|
||||
width += 1
|
||||
|
||||
height = max(height, 1)
|
||||
width = max(width, 1)
|
||||
|
||||
return height, width
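A quick numeric sketch of the rule implemented above, assuming a 1080p frame and the SmolVLM default longest edge of 4 * 364 = 1456:

height, width = 1080, 1920
resolution_max_side = min(4096, 4 * 364)      # capped by MAX_IMAGE_SIZE, i.e. 1456
aspect_ratio = width / height

new_width = resolution_max_side               # the longest edge becomes 1456
new_height = int(new_width / aspect_ratio)    # 819, aspect ratio preserved
if new_height % 2 != 0:                       # odd dimensions are bumped to the next even value
    new_height += 1
print(new_height, new_width)                  # 820 1456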
|
||||
|
||||
|
||||
class SmolVLMVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class SmolVLMVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.LANCZOS
|
||||
size = {"longest_edge": 4 * 364}
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
valid_kwargs = SmolVLMVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values", "pixel_attention_mask"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[SmolVLMVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def resize(
|
||||
self,
|
||||
video: "torch.Tensor",
|
||||
size: SizeDict,
|
||||
interpolation: "F.InterpolationMode" = None,
|
||||
antialias: bool = True,
|
||||
**kwargs,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Resize a video to `(size["height"], size["width"])`.
|
||||
Args:
|
||||
video (`torch.Tensor`):
|
||||
Video to resize.
|
||||
size (`SizeDict`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output video.
|
||||
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
|
||||
`InterpolationMode` filter to use when resizing the video e.g. `InterpolationMode.BICUBIC`.
|
||||
Returns:
|
||||
`torch.Tensor`: The resized video.
|
||||
"""
|
||||
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
|
||||
if interpolation == F.InterpolationMode.LANCZOS:
|
||||
logger.warning_once(
|
||||
"You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
|
||||
"BICUBIC resample will be used as an alternative. Please fall back to image processor if you "
|
||||
"want full consistency with the original model."
|
||||
)
|
||||
interpolation = F.InterpolationMode.BICUBIC
|
||||
|
||||
if size.longest_edge:
|
||||
# Resize the image so that the shortest edge or the longest edge is of the given size
|
||||
# while maintaining the aspect ratio of the original image.
|
||||
new_size = get_resize_output_image_size(
|
||||
video,
|
||||
resolution_max_side=size.longest_edge,
|
||||
)
|
||||
elif size.height and size.width:
|
||||
new_size = (size.height, size.width)
|
||||
else:
|
||||
raise ValueError(f"Size must contain 'height' and 'width' keys, or 'longest_edge' key. Got {size}.")
|
||||
return F.resize(video, new_size, interpolation=interpolation, antialias=antialias)
|
||||
|
||||
def pad(
|
||||
self,
|
||||
video: "torch.Tensor",
|
||||
padded_size: tuple[int, int],
|
||||
fill: int = 0,
|
||||
return_pixel_mask: bool = True,
|
||||
):
|
||||
"""Pads the sample with empty video to the padded_size
|
||||
Args:
|
||||
video (`torch.Tensor`):
|
||||
Video to pad.
|
||||
padded_size (`Tuple[int, int]`):
|
||||
Height and width to pad.
|
||||
fill (`int`, *optional*):
|
||||
The value to use for the padding.
|
||||
return_pixel_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a pixel mask.
|
||||
"""
|
||||
original_size = video.size()[-2:]
|
||||
padding_bottom = padded_size[0] - original_size[0]
|
||||
padding_right = padded_size[1] - original_size[1]
|
||||
if padding_bottom < 0 or padding_right < 0:
|
||||
raise ValueError(
|
||||
f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
|
||||
f"original size. Got padded size: {padded_size}, original size: {original_size}."
|
||||
)
|
||||
if original_size != padded_size:
|
||||
padding = [0, 0, padding_right, padding_bottom]
|
||||
video = F.pad(video, padding, fill=fill)
|
||||
|
||||
# Make a pixel mask for the video, where 1 indicates a valid pixel and 0 indicates padding.
|
||||
pixel_mask = None
|
||||
if return_pixel_mask:
|
||||
pixel_mask = torch.zeros_like(video[..., 0, :, :], dtype=torch.int64)
|
||||
pixel_mask[..., : original_size[0], : original_size[1]] = 1
|
||||
|
||||
return video, pixel_mask
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: List["torch.Tensor"],
|
||||
do_convert_rgb: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
do_pad: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
if do_convert_rgb:
|
||||
stacked_videos = self.convert_to_rgb(stacked_videos)
|
||||
if do_resize:
|
||||
stacked_videos = self.resize(stacked_videos, size=size, interpolation=interpolation)
|
||||
resized_videos_grouped[shape] = stacked_videos
|
||||
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
|
||||
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
|
||||
processed_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
stacked_videos = self.rescale_and_normalize(
|
||||
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_videos_grouped[shape] = stacked_videos
|
||||
|
||||
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
|
||||
|
||||
if do_pad:
|
||||
pad_size = get_max_height_width(processed_videos)
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
|
||||
processed_padded_mask_grouped = {}
|
||||
processed_videos_grouped = {}
|
||||
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
stacked_videos, padded_masks = self.pad(stacked_videos, padded_size=pad_size)
|
||||
processed_videos_grouped[shape] = stacked_videos
|
||||
processed_padded_mask_grouped[shape] = padded_masks
|
||||
|
||||
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
|
||||
pixel_attention_mask = reorder_videos(processed_padded_mask_grouped, grouped_videos_index)
|
||||
|
||||
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
|
||||
data = {"pixel_values": processed_videos}
|
||||
|
||||
if do_pad:
|
||||
data["pixel_attention_mask"] = (
|
||||
torch.stack(pixel_attention_mask, dim=0)
|
||||
if do_pad and return_tensors is not None
|
||||
else pixel_attention_mask
|
||||
)
|
||||
return BatchFeature(data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["SmolVLMVideoProcessor"]
|
||||
|
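A minimal local sketch of the class above (requires the torchvision backend; no checkpoint download, so all class defaults apply). With `do_pad=True` the output carries both the padded frames and a pixel mask:

import numpy as np
from transformers.models.smolvlm.video_processing_smolvlm import SmolVLMVideoProcessor

video_processor = SmolVLMVideoProcessor()
video = np.random.randint(0, 256, size=(8, 360, 640, 3), dtype=np.uint8)  # 8 dummy frames

out = video_processor(video, return_tensors="pt")
print(out["pixel_values"].shape)           # roughly (batch, frames, channels, height, width)
print(out["pixel_attention_mask"].shape)   # 1 where pixels are real, 0 where padded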
@ -31,16 +31,15 @@ from ...image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_batched_videos,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
|
||||
from ...video_utils import VideoInput, make_batched_videos
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -259,10 +258,8 @@ class VideoLlavaImageProcessor(BaseImageProcessor):
|
||||
|
||||
if images is not None:
|
||||
images = make_list_of_images(images)
|
||||
if videos is not None:
|
||||
videos = make_batched_videos(videos)
|
||||
|
||||
if (videos is not None and not valid_images(videos)) or (images is not None and not valid_images(images)):
|
||||
if images is not None and not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
@ -270,6 +267,12 @@ class VideoLlavaImageProcessor(BaseImageProcessor):
|
||||
|
||||
data = {}
|
||||
if videos is not None:
|
||||
logger.warning(
|
||||
"`VideoLlavaImageProcessor` works only with image inputs and doesn't process videos anymore. "
|
||||
"This is a deprecated behavior and will be removed in v5.0. "
|
||||
"Your videos should be forwarded to `VideoLlavaVideoProcessor`. "
|
||||
)
|
||||
videos = make_batched_videos(videos)
|
||||
pixel_values_videos = [
|
||||
[
|
||||
self._preprocess_image(
|
||||
|
@ -40,6 +40,8 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
Args:
|
||||
image_processor ([`VideoLlavaImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
video_processor ([`VideoLlavaVideoProcessor`], *optional*):
|
||||
The video processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
@ -58,7 +60,7 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
extra tokens appended, no need to set this arg.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
attributes = ["image_processor", "video_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"patch_size",
|
||||
@ -68,11 +70,13 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
"num_additional_image_tokens",
|
||||
]
|
||||
image_processor_class = "VideoLlavaImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
video_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size=14,
|
||||
vision_feature_select_strategy="default",
|
||||
@ -89,7 +93,7 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
super().__init__(image_processor, video_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -150,54 +154,52 @@ class VideoLlavaProcessor(ProcessorMixin):
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values to be fed to a model. Returned when `videos` is not `None`.
|
||||
"""
|
||||
data = {}
|
||||
if images is not None or videos is not None:
|
||||
encoded_images = self.image_processor(images=images, videos=videos, return_tensors=return_tensors)
|
||||
if images is not None:
|
||||
encoded_images = self.image_processor(images=images, return_tensors=return_tensors)
|
||||
data.update(encoded_images)
|
||||
|
||||
if videos is not None:
|
||||
encoded_videos = self.video_processor(videos=videos, return_tensors=return_tensors)
|
||||
data.update(encoded_videos)
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
prompt_strings = text
|
||||
|
||||
if encoded_images is not None:
|
||||
if "pixel_values_images" in encoded_images.keys():
|
||||
height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0]))
|
||||
num_frames = 1
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
|
||||
num_image_tokens += self.num_additional_image_tokens
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
text = [sample.replace(self.image_token, self.image_token * num_image_tokens) for sample in text]
|
||||
|
||||
if "pixel_values_videos" in encoded_images.keys():
|
||||
one_video = encoded_images.get("pixel_values_videos")[0]
|
||||
if isinstance(encoded_images.get("pixel_values_videos")[0], (list, tuple)):
|
||||
if encoded_videos is not None:
|
||||
one_video = encoded_videos.get("pixel_values_videos")[0]
|
||||
if isinstance(encoded_videos.get("pixel_values_videos")[0], (list, tuple)):
|
||||
one_video = np.array(one_video)
|
||||
else:
|
||||
one_video = to_numpy_array(one_video)
|
||||
height, width = get_image_size(one_video[0])
|
||||
num_frames = one_video.shape[0] # frame dim is always after batch dim
|
||||
|
||||
num_image_tokens = (height // self.patch_size) * (
|
||||
width // self.patch_size
|
||||
) + self.num_additional_image_tokens
|
||||
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
|
||||
num_image_tokens += self.num_additional_image_tokens
|
||||
num_video_tokens = num_image_tokens * num_frames
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
prompt_strings = []
|
||||
for sample in text:
|
||||
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
|
||||
sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
|
||||
prompt_strings.append(sample)
|
||||
text = [sample.replace(self.video_token, self.video_token * num_video_tokens) for sample in text]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt_strings,
|
||||
text,
|
||||
return_tensors=None,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
)
|
||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image", "video"])
|
||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
||||
|
||||
data.update(text_inputs)
|
||||
|
||||
|
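The replacement logic above expands each video placeholder into one token per visual patch. A rough arithmetic sketch, assuming 224x224 frames, patch_size=14, num_additional_image_tokens=1, 8 sampled frames and the "default" feature-select strategy:

height = width = 224
patch_size, num_additional_image_tokens, num_frames = 14, 1, 8

num_image_tokens = (height // patch_size) * (width // patch_size)  # 256 patches per frame
num_image_tokens += num_additional_image_tokens                    # +1 extra token -> 257
num_video_tokens = num_image_tokens * num_frames                   # 2056 tokens per video
num_image_tokens -= 1                                              # "default" strategy drops one token for images
print(num_image_tokens, num_video_tokens)                          # 256 2056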
@ -0,0 +1,56 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Video processor class for Video-LLaVA."""
|
||||
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import is_vision_available
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
class VideoLlavaFastVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class VideoLlavaVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"shortest_edge": 224}
|
||||
default_to_square = False
|
||||
crop_size = {"height": 224, "width": 224}
|
||||
do_resize = True
|
||||
do_center_crop = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = VideoLlavaFastVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[VideoLlavaFastVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["VideoLlavaVideoProcessor"]
|
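As the file above shows, a model-specific video processor can often be defined purely declaratively on top of `BaseVideoProcessor`. A sketch of that pattern with a made-up model name (`MyModelVideoProcessor` is hypothetical):

from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling
from transformers.video_processing_utils import BaseVideoProcessor


class MyModelVideoProcessor(BaseVideoProcessor):
    # only class attributes are overridden; resizing, rescaling and normalization
    # behavior is inherited from BaseVideoProcessor
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"shortest_edge": 256}
    crop_size = {"height": 224, "width": 224}
    do_resize = True
    do_center_crop = True
    do_rescale = True
    do_normalize = True
    model_input_names = ["pixel_values_videos"]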
@ -32,16 +32,9 @@ from huggingface_hub.errors import EntryNotFoundError
|
||||
from .audio_utils import load_audio
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .feature_extraction_utils import BatchFeature
|
||||
from .image_utils import (
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
VideoInput,
|
||||
is_valid_image,
|
||||
is_vision_available,
|
||||
load_image,
|
||||
load_video,
|
||||
)
|
||||
from .image_utils import ChannelDimension, ImageInput, is_valid_image, is_vision_available, load_image
|
||||
from .utils.chat_template_utils import render_jinja_template
|
||||
from .video_utils import VideoInput, load_video
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@ -84,6 +77,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
|
||||
"AutoTokenizer": "PreTrainedTokenizerBase",
|
||||
"AutoFeatureExtractor": "FeatureExtractionMixin",
|
||||
"AutoImageProcessor": "ImageProcessingMixin",
|
||||
"AutoVideoProcessor": "BaseVideoProcessor",
|
||||
}
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@ -193,7 +187,7 @@ class ImagesKwargs(TypedDict, total=False):
|
||||
do_resize: Optional[bool]
|
||||
size: Optional[dict[str, int]]
|
||||
size_divisor: Optional[int]
|
||||
crop_size: Optional[dict[str, int]]
|
||||
crop_size: Optional[Dict[str, int]]
|
||||
resample: Optional[Union["PILImageResampling", int]]
|
||||
do_rescale: Optional[bool]
|
||||
rescale_factor: Optional[float]
|
||||
@ -213,37 +207,45 @@ class VideosKwargs(TypedDict, total=False):
|
||||
Keyword arguments for video processing.
|
||||
|
||||
Attributes:
|
||||
do_convert_rgb (`bool`):
|
||||
Whether to convert the video to RGB format.
|
||||
do_resize (`bool`):
|
||||
Whether to resize the image.
|
||||
Whether to resize the video.
|
||||
size (`Dict[str, int]`, *optional*):
|
||||
Resize the shorter side of the input to `size["shortest_edge"]`.
|
||||
default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
|
||||
Whether to default to a square when resizing, if size is an int.
|
||||
size_divisor (`int`, *optional*):
|
||||
The size by which to make sure both the height and width can be divided.
|
||||
resample (`PILImageResampling`, *optional*):
|
||||
Resampling filter to use if resizing the image.
|
||||
Resampling filter to use if resizing the video.
|
||||
do_rescale (`bool`, *optional*):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`.
|
||||
Whether to rescale the video by the specified scale `rescale_factor`.
|
||||
rescale_factor (`int` or `float`, *optional*):
|
||||
Scale factor to use if rescaling the image.
|
||||
Scale factor to use if rescaling the video.
|
||||
do_normalize (`bool`, *optional*):
|
||||
Whether to normalize the image.
|
||||
Whether to normalize the video.
|
||||
image_mean (`float` or `List[float]`, *optional*):
|
||||
Mean to use if normalizing the image.
|
||||
Mean to use if normalizing the video.
|
||||
image_std (`float` or `List[float]`, *optional*):
|
||||
Standard deviation to use if normalizing the image.
|
||||
Standard deviation to use if normalizing the video.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
|
||||
Whether to pad the video to the `(max_height, max_width)` of the videos in the batch.
|
||||
do_center_crop (`bool`, *optional*):
|
||||
Whether to center crop the image.
|
||||
Whether to center crop the video.
|
||||
crop_size (`Dict[str, int]`, *optional*):
|
||||
Desired output size when applying center-cropping.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image.
|
||||
The channel dimension format for the output video.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image.
|
||||
The channel dimension format for the input video.
|
||||
"""
|
||||
|
||||
do_convert_rgb: Optional[bool]
|
||||
do_resize: Optional[bool]
|
||||
size: Optional[dict[str, int]]
|
||||
size_divisor: Optional[int]
|
||||
default_to_square: Optional[bool]
|
||||
resample: Optional["PILImageResampling"]
|
||||
do_rescale: Optional[bool]
|
||||
rescale_factor: Optional[float]
|
||||
@ -252,8 +254,10 @@ class VideosKwargs(TypedDict, total=False):
|
||||
image_std: Optional[Union[float, list[float]]]
|
||||
do_pad: Optional[bool]
|
||||
do_center_crop: Optional[bool]
|
||||
crop_size: Optional[Dict[str, int]]
|
||||
data_format: Optional[ChannelDimension]
|
||||
input_data_format: Optional[Union[str, ChannelDimension]]
|
||||
device: Optional[str]
|
||||
|
||||
|
||||
class AudioKwargs(TypedDict, total=False):
|
||||
@ -532,6 +536,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
del output["tokenizer"]
|
||||
if "image_processor" in output:
|
||||
del output["image_processor"]
|
||||
if "video_processor" in output:
|
||||
del output["video_processor"]
|
||||
if "feature_extractor" in output:
|
||||
del output["feature_extractor"]
|
||||
if "chat_template" in output:
|
||||
@ -1248,6 +1254,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
return getattr(transformers_module, module_name)
|
||||
lookup_locations = [
|
||||
transformers_module.IMAGE_PROCESSOR_MAPPING,
|
||||
transformers_module.VIDEO_PROCESSOR_MAPPING,
|
||||
transformers_module.TOKENIZER_MAPPING,
|
||||
transformers_module.FEATURE_EXTRACTOR_MAPPING,
|
||||
]
|
||||
|
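The `VideosKwargs` group above is what lets a processor route video-specific options to its video processor at call time. A hedged sketch (example checkpoint id and dummy video; the modality-scoped `videos_kwargs` dict is the assumed calling convention of the kwargs-uniformization API):

import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")  # example checkpoint
video = np.zeros((8, 336, 336, 3), dtype=np.uint8)

inputs = processor(
    text="USER: <video> What happens in this clip? ASSISTANT:",
    videos=video,
    videos_kwargs={"do_resize": True, "size": {"shortest_edge": 224}},  # illustrative override only
    return_tensors="pt",
)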
@ -78,6 +78,7 @@ from .utils import (
|
||||
is_compressed_tensors_available,
|
||||
is_cv2_available,
|
||||
is_cython_available,
|
||||
is_decord_available,
|
||||
is_detectron2_available,
|
||||
is_eetq_available,
|
||||
is_essentia_available,
|
||||
@ -1247,6 +1248,13 @@ def require_av(test_case):
|
||||
return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
|
||||
|
||||
|
||||
def require_decord(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires decord
|
||||
"""
|
||||
return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)
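A small sketch of how the new decorator is intended to be used in the test suite (the test class and method names are made up):

import unittest

from transformers.testing_utils import require_decord


class VideoLoadingTest(unittest.TestCase):
    @require_decord
    def test_load_video_with_decord_backend(self):
        # skipped automatically when decord is not installed
        self.assertTrue(True)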
|
||||
|
||||
|
||||
def require_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
|
||||
|
@ -285,7 +285,8 @@ SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||
CONFIG_NAME = "config.json"
|
||||
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
|
||||
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
|
||||
IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
|
||||
VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
|
||||
PROCESSOR_NAME = "processor_config.json"
|
||||
GENERATION_CONFIG_NAME = "generation_config.json"
|
||||
MODEL_CARD_NAME = "modelcard.json"
|
||||
|
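Video processor configs now get their own file name, so they no longer collide with `preprocessor_config.json`. A hedged round-trip sketch (requires the torchvision backend; assumes `save_pretrained`/`from_pretrained` mirror the image processor API):

import os
import tempfile

from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor

video_processor = Qwen2VLVideoProcessor()
with tempfile.TemporaryDirectory() as tmp_dir:
    video_processor.save_pretrained(tmp_dir)
    print(os.listdir(tmp_dir))                          # expect video_preprocessor_config.json
    reloaded = Qwen2VLVideoProcessor.from_pretrained(tmp_dir)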
@ -7,3 +7,10 @@ class BaseImageProcessorFast(metaclass=DummyObject):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class BaseVideoProcessor(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
800
src/transformers/video_processing_utils.py
Normal file
@ -0,0 +1,800 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_size_dict,
|
||||
)
|
||||
from .image_processing_utils_fast import BaseImageProcessorFast
|
||||
from .image_utils import (
|
||||
ChannelDimension,
|
||||
SizeDict,
|
||||
validate_kwargs,
|
||||
)
|
||||
from .processing_utils import Unpack, VideosKwargs
|
||||
from .utils import (
|
||||
VIDEO_PROCESSOR_NAME,
|
||||
TensorType,
|
||||
add_model_info_to_auto_map,
|
||||
add_model_info_to_custom_pipelines,
|
||||
add_start_docstrings,
|
||||
cached_file,
|
||||
copy_func,
|
||||
download_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_torch_available,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
)
|
||||
from .utils.import_utils import requires
|
||||
from .video_utils import (
|
||||
VideoInput,
|
||||
group_videos_by_shape,
|
||||
load_video,
|
||||
make_batched_videos,
|
||||
reorder_videos,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from .image_utils import PILImageResampling
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchvision_available():
|
||||
from .image_utils import pil_torch_interpolation_mapping
|
||||
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING = r"""
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the video's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`dict`, *optional*, defaults to `self.size`):
|
||||
Size of the output video after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
|
||||
The size by which to make sure both the height and width can be divided.
|
||||
default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
|
||||
Whether to default to a square video when resizing, if size is an int.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
||||
Whether to center crop the video to the specified `crop_size`. Can be overridden by `do_center_crop` in the
|
||||
`preprocess` method.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the video to the `(max_height, max_width)` of the videos in the batch.
|
||||
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
||||
Size of the output video after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
|
||||
method.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the video by the specified scale `rescale_factor`. Can be overridden by the
|
||||
`do_rescale` parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Scale factor to use if rescaling the video. Only has an effect if `do_rescale` is set to `True`. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the video. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Mean to use if normalizing the video. This is a float or list of floats the length of the number of
channels in the video. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to use if normalizing the video. This is a float or list of floats the length of the
number of channels in the video. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the video to RGB.
return_tensors (`str` or `TensorType`, *optional*):
Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output video. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input video.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input video. If unset, the channel dimension format is inferred
|
||||
from the input video. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: video in (height, width) format.
|
||||
device (`torch.device`, *optional*):
|
||||
The device to process the videos on. If unset, the device is inferred from the input videos."""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a base VideoProcessor.",
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
||||
)
|
||||
@requires(backends=("vision", "torchvision"))
|
||||
class BaseVideoProcessor(BaseImageProcessorFast):
|
||||
_auto_class = None
|
||||
|
||||
resample = None
|
||||
image_mean = None
|
||||
image_std = None
|
||||
size = None
|
||||
size_divisor = None
|
||||
default_to_square = True
|
||||
crop_size = None
|
||||
do_resize = None
|
||||
do_center_crop = None
|
||||
do_pad = None
|
||||
do_rescale = None
|
||||
rescale_factor = 1 / 255
|
||||
do_normalize = None
|
||||
do_convert_rgb = None
|
||||
valid_kwargs = VideosKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[VideosKwargs]) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._processor_class = kwargs.pop("processor_class", None)
|
||||
|
||||
# Additional attributes without default values
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except AttributeError as err:
|
||||
logger.error(f"Can't set {key} with value {value} for {self}")
|
||||
raise err
|
||||
|
||||
# Prepare size-related keys and turn them into `SizeDict`
|
||||
size = kwargs.pop("size", self.size)
|
||||
self.size = (
|
||||
get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
|
||||
if size is not None
|
||||
else None
|
||||
)
|
||||
crop_size = kwargs.pop("crop_size", self.crop_size)
|
||||
self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
|
||||
|
||||
# Save valid kwargs in a list for further processing
|
||||
self.model_valid_processing_keys = list(self.valid_kwargs.__annotations__.keys())
|
||||
for key in self.model_valid_processing_keys:
|
||||
if kwargs.get(key) is not None:
|
||||
setattr(self, key, kwargs[key])
|
||||
else:
|
||||
setattr(self, key, getattr(self, key, None))
|
||||
|
||||
def __call__(self, videos, **kwargs) -> BatchFeature:
|
||||
return self.preprocess(videos, **kwargs)
|
||||
|
||||
def convert_to_rgb(
|
||||
self,
|
||||
video: "torch.Tensor",
|
||||
) -> VideoInput:
|
||||
"""
|
||||
Converts a video to RGB format.
|
||||
|
||||
Args:
|
||||
video (`"torch.Tensor"`):
|
||||
The video to convert.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The converted video.
|
||||
"""
|
||||
|
||||
video = F.grayscale_to_rgb(video)
|
||||
if video.shape[-3] == 3 or not (video[..., 3, :, :] < 255).any():
|
||||
return video
|
||||
|
||||
# There is a transparency layer, blend it with a white background.
|
||||
# Calculate the alpha proportion for blending.
|
||||
alpha = video[..., 3, :, :] / 255.0
|
||||
video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
|
||||
return video
|
||||
|
||||
def _prepare_input_videos(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
) -> List["torch.Tensor"]:
|
||||
"""
|
||||
Prepare the input videos for processing.
|
||||
"""
|
||||
videos = make_batched_videos(videos)
|
||||
processed_videos = []
|
||||
for video in videos:
|
||||
# `make_batched_videos` always returns a 4D array per video
|
||||
if isinstance(video, np.ndarray):
|
||||
video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format)
|
||||
# not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
|
||||
video = torch.from_numpy(video).contiguous()
|
||||
|
||||
# Now that we have torch tensors, we can move them to the right device
|
||||
if device is not None:
|
||||
video = video.to(device)
|
||||
|
||||
processed_videos.append(video)
|
||||
return processed_videos
|
||||
|
||||
@add_start_docstrings(BASE_VIDEO_PROCESSOR_DOCSTRING)
|
||||
def preprocess(
|
||||
self,
|
||||
videos: VideoInput,
|
||||
**kwargs: Unpack[VideosKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)
|
||||
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
resample = kwargs.pop("resample")
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
|
||||
)
|
||||
|
||||
# Pop kwargs that are not needed in _preprocess
|
||||
kwargs.pop("default_to_square")
|
||||
kwargs.pop("data_format")
|
||||
|
||||
return self._preprocess(videos=videos, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
videos: List["torch.Tensor"],
|
||||
do_convert_rgb: bool,
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
size_divisor: Optional[int],
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
do_pad: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> BatchFeature:
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
if do_convert_rgb:
|
||||
stacked_videos = self.convert_to_rgb(stacked_videos)
|
||||
if do_resize:
|
||||
stacked_videos = self.resize(
|
||||
stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation
|
||||
)
|
||||
resized_videos_grouped[shape] = stacked_videos
|
||||
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
|
||||
|
||||
# Group videos by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns videos with different sizes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
|
||||
processed_videos_grouped = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
if do_center_crop:
|
||||
stacked_videos = self.center_crop(stacked_videos, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_videos = self.rescale_and_normalize(
|
||||
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_videos_grouped[shape] = stacked_videos
|
||||
|
||||
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
|
||||
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
|
||||
|
||||
return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a type of [`~video_processing_utils.VideoProcessorBase`] from a video processor.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained video processor hosted inside a model repo on
|
||||
huggingface.co.
|
||||
- a path to a *directory* containing a video processor file saved using the
|
||||
[`~video_processing_utils.VideoProcessorBase.save_pretrained`] method, e.g.,
|
||||
`./my_model_directory/`.
|
||||
- a path or url to a saved video processor JSON *file*, e.g.,
|
||||
`./my_model_directory/preprocessor_config.json`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model video processor should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the video processor files and override the cached versions if
|
||||
they exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
token (`str` or `bool`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
||||
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
|
||||
|
||||
</Tip>
|
||||
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
If `False`, then this function returns just the final video processor object. If `True`, then this
|
||||
function returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
|
||||
consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
|
||||
`kwargs` which has not been used to update `video_processor` and is otherwise ignored.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are video processor attributes will be used to override the
|
||||
loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
|
||||
controlled by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
Returns:
|
||||
A video processor of type [`~video_processing_utils.VideoProcessorBase`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# We can't directly instantiate the base class *VideoProcessorBase*, so let's show the examples on a
|
||||
# derived class: *LlavaOnevisionVideoProcessor*
|
||||
video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
|
||||
) # Download video_processing_config from huggingface.co and cache.
|
||||
video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
|
||||
"./test/saved_model/"
|
||||
) # E.g. video processor (or model) was saved using *save_pretrained('./test/saved_model/')*
|
||||
video_processor = LlavaOnevisionVideoProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
|
||||
video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False
|
||||
)
|
||||
assert video_processor.do_normalize is False
|
||||
video_processor, unused_kwargs = LlavaOnevisionVideoProcessor.from_pretrained(
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False, return_unused_kwargs=True
|
||||
)
|
||||
assert video_processor.do_normalize is False
|
||||
assert unused_kwargs == {"foo": False}
|
||||
```"""
|
||||
kwargs["cache_dir"] = cache_dir
|
||||
kwargs["force_download"] = force_download
|
||||
kwargs["local_files_only"] = local_files_only
|
||||
kwargs["revision"] = revision
|
||||
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if token is not None:
|
||||
raise ValueError(
|
||||
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
||||
)
|
||||
token = use_auth_token
|
||||
|
||||
if token is not None:
|
||||
kwargs["token"] = token
|
||||
|
||||
video_processor_dict, kwargs = cls.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return cls.from_dict(video_processor_dict, **kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a video processor object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
[`~video_processing_utils.VideoProcessorBase.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the video processor JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if kwargs.get("token", None) is not None:
|
||||
raise ValueError(
|
||||
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
||||
)
|
||||
kwargs["token"] = use_auth_token
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
custom_object_save(self, save_directory, config=self)
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_video_processor_file = os.path.join(save_directory, VIDEO_PROCESSOR_NAME)
|
||||
|
||||
self.to_json_file(output_video_processor_file)
|
||||
logger.info(f"Video processor saved in {output_video_processor_file}")
|
||||
|
||||
if push_to_hub:
|
||||
self._upload_modified_files(
|
||||
save_directory,
|
||||
repo_id,
|
||||
files_timestamps,
|
||||
commit_message=commit_message,
|
||||
token=kwargs.get("token"),
|
||||
)
|
||||
|
||||
return [output_video_processor_file]
|
||||
|
||||
@classmethod
|
||||
def get_video_processor_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
video processor of type [`~video_processing_utils.VideoProcessorBase`] using `from_dict`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||
specify the folder name here.
|
||||
|
||||
Returns:
|
||||
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the video processor object.
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if token is not None:
|
||||
raise ValueError(
|
||||
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
||||
)
|
||||
token = use_auth_token
|
||||
|
||||
user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class}
|
||||
if from_pipeline is not None:
|
||||
user_agent["using_pipeline"] = from_pipeline
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_video_processor_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
video_processor_file = pretrained_model_name_or_path
|
||||
resolved_video_processor_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
try:
|
||||
# Try to load with the new config name first and, if not successful, try with
|
||||
# the old file name. In case we can load with old name only, raise a deprecation warning
|
||||
# Deprecated until v5.0
|
||||
video_processor_file = VIDEO_PROCESSOR_NAME
|
||||
resolved_video_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
video_processor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
except EnvironmentError:
|
||||
video_processor_file = "preprocessor_config.json"
|
||||
resolved_video_processor_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
video_processor_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
logger.warning_once(
|
||||
"You have video processor config saved in `preprocessor.json` file which is deprecated. "
|
||||
"Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename "
|
||||
"the file or load and save the processor back which renames it automatically. "
|
||||
"Loading from `preprocessor.json` will be removed in v5.0."
|
||||
)
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
|
||||
# the original exception.
|
||||
raise
|
||||
except Exception:
|
||||
# For any other exception, we throw a generic error.
|
||||
raise EnvironmentError(
|
||||
f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||
f" directory containing a {VIDEO_PROCESSOR_NAME} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load video_processor dict
|
||||
with open(resolved_video_processor_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
video_processor_dict = json.loads(text)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_video_processor_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
if is_local:
|
||||
logger.info(f"loading configuration file {resolved_video_processor_file}")
|
||||
else:
|
||||
logger.info(
|
||||
f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}"
|
||||
)
|
||||
|
||||
if not is_local:
|
||||
if "auto_map" in video_processor_dict:
|
||||
video_processor_dict["auto_map"] = add_model_info_to_auto_map(
|
||||
video_processor_dict["auto_map"], pretrained_model_name_or_path
|
||||
)
|
||||
if "custom_pipelines" in video_processor_dict:
|
||||
video_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
|
||||
video_processor_dict["custom_pipelines"], pretrained_model_name_or_path
|
||||
)
|
||||
return video_processor_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, video_processor_dict: Dict[str, Any], **kwargs):
|
||||
"""
|
||||
Instantiates a type of [`~video_processing_utils.VideoProcessorBase`] from a Python dictionary of parameters.
|
||||
|
||||
Args:
|
||||
video_processor_dict (`Dict[str, Any]`):
|
||||
Dictionary that will be used to instantiate the video processor object. Such a dictionary can be
|
||||
retrieved from a pretrained checkpoint by leveraging the
|
||||
[`~video_processing_utils.VideoProcessorBase.to_dict`] method.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
Additional parameters from which to initialize the video processor object.
|
||||
|
||||
Returns:
|
||||
[`~video_processing_utils.VideoProcessorBase`]: The video processor object instantiated from those
|
||||
parameters.
|
||||
"""
|
||||
video_processor_dict = video_processor_dict.copy()
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
|
||||
# The `size` parameter is a dict and was previously an int or tuple in feature extractors.
|
||||
# We set `size` here directly to the `video_processor_dict` so that it is converted to the appropriate
|
||||
# dict within the video processor and isn't overwritten if `size` is passed in as a kwarg.
|
||||
if "size" in kwargs and "size" in video_processor_dict:
|
||||
video_processor_dict["size"] = kwargs.pop("size")
|
||||
if "crop_size" in kwargs and "crop_size" in video_processor_dict:
|
||||
video_processor_dict["crop_size"] = kwargs.pop("crop_size")
|
||||
|
||||
video_processor = cls(**video_processor_dict)
|
||||
|
||||
# Update video_processor with kwargs if needed
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(video_processor, key):
|
||||
setattr(video_processor, key, value)
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info(f"Video processor {video_processor}")
|
||||
if return_unused_kwargs:
|
||||
return video_processor, kwargs
|
||||
else:
|
||||
return video_processor
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["video_processor_type"] = self.__class__.__name__
|
||||
|
||||
return output
|
||||
|
||||
def to_json_string(self) -> str:
|
||||
"""
|
||||
Serializes this instance to a JSON string.
|
||||
|
||||
Returns:
|
||||
`str`: String containing all the attributes that make up this video processor instance in JSON format.
|
||||
"""
|
||||
dictionary = self.to_dict()
|
||||
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
dictionary[key] = value.tolist()
|
||||
|
||||
# make sure private name "_processor_class" is correctly
|
||||
# saved as "processor_class"
|
||||
_processor_class = dictionary.pop("_processor_class", None)
|
||||
if _processor_class is not None:
|
||||
dictionary["processor_class"] = _processor_class
|
||||
|
||||
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (`str` or `os.PathLike`):
|
||||
Path to the JSON file in which this video processor instance's parameters will be saved.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
"""
|
||||
Instantiates a video processor of type [`~video_processing_utils.VideoProcessorBase`] from the path to a JSON
|
||||
file of parameters.
|
||||
|
||||
Args:
|
||||
json_file (`str` or `os.PathLike`):
|
||||
Path to the JSON file containing the parameters.
|
||||
|
||||
Returns:
|
||||
A video processor of type [`~video_processing_utils.VideoProcessorBase`]: The video_processor object
|
||||
instantiated from that JSON file.
|
||||
"""
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
video_processor_dict = json.loads(text)
|
||||
return cls(**video_processor_dict)
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoVideoProcessor"):
|
||||
"""
|
||||
Register this class with a given auto class. This should only be used for custom video processors as the ones
|
||||
in the library are already mapped with `AutoVideoProcessor`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is experimental and may have some slight breaking changes in the next releases.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
auto_class (`str` or `type`, *optional*, defaults to `"AutoVideoProcessor"`):
|
||||
The auto class to register this new video processor with.
|
||||
"""
|
||||
if not isinstance(auto_class, str):
|
||||
auto_class = auto_class.__name__
|
||||
|
||||
import transformers.models.auto as auto_module
|
||||
|
||||
if not hasattr(auto_module, auto_class):
|
||||
raise ValueError(f"{auto_class} is not a valid auto class.")
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
def fetch_videos(self, video_url_or_urls: Union[str, List[str]]):
|
||||
"""
|
||||
Convert a single or a list of urls into the corresponding `np.array` objects.
|
||||
|
||||
If a single url is passed, the return value will be a single object. If a list is passed, a list of objects is
|
||||
returned.
|
||||
"""
|
||||
if isinstance(video_url_or_urls, list):
|
||||
return [self.fetch_videos(x) for x in video_url_or_urls]
|
||||
elif isinstance(video_url_or_urls, str):
|
||||
return load_video(video_url_or_urls)
|
||||
else:
|
||||
raise TypeError(f"only a single or a list of entries is supported but got type={type(video_url_or_urls)}")
|
||||
|
||||
|
||||
BaseVideoProcessor.push_to_hub = copy_func(BaseVideoProcessor.push_to_hub)
|
||||
if BaseVideoProcessor.push_to_hub.__doc__ is not None:
|
||||
BaseVideoProcessor.push_to_hub.__doc__ = BaseVideoProcessor.push_to_hub.__doc__.format(
|
||||
object="video processor", object_class="AutoVideoProcessor", object_files="video processor file"
|
||||
)
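
# Illustrative usage sketch of the class above (assumes the same checkpoint as the
# `from_pretrained` docstring example; the exact output shape depends on the model's size config).
def _example_video_processor_call():
    import numpy as np
    from transformers import AutoVideoProcessor

    video = np.random.randint(0, 256, size=(8, 360, 640, 3), dtype=np.uint8)  # 8 RGB frames
    video_processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
    inputs = video_processor(video, return_tensors="pt")
    return inputs["pixel_values_videos"]  # stacked tensor of the processed frames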
|
717
src/transformers/video_utils.py
Normal file
@ -0,0 +1,717 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import redirect_stdout
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from .image_transforms import PaddingMode, to_channel_dimension_format
|
||||
from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
|
||||
from .utils import (
|
||||
is_av_available,
|
||||
is_cv2_available,
|
||||
is_decord_available,
|
||||
is_numpy_array,
|
||||
is_torch_available,
|
||||
is_torch_tensor,
|
||||
is_torchvision_available,
|
||||
is_vision_available,
|
||||
is_yt_dlp_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
|
||||
if is_torchvision_available():
|
||||
from torchvision import io as torchvision_io
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
VideoInput = Union[
|
||||
List["PIL.Image.Image"],
|
||||
"np.ndarray",
|
||||
"torch.Tensor",
|
||||
List["np.ndarray"],
|
||||
List["torch.Tensor"],
|
||||
List[List["PIL.Image.Image"]],
|
||||
List[List["np.ndarrray"]],
|
||||
List[List["torch.Tensor"]],
|
||||
] # noqa
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoMetadata:
|
||||
total_num_frames: int
|
||||
fps: float
|
||||
duration: float
|
||||
video_backend: str
|
||||
|
||||
|
||||
def is_valid_video_frame(frame):
|
||||
return isinstance(frame, PIL.Image.Image) or (
|
||||
(is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
|
||||
)
|
||||
|
||||
|
||||
def is_valid_video(video):
|
||||
if not isinstance(video, (list, tuple)):
|
||||
return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
|
||||
return all(is_valid_video_frame(frame) for frame in video)
|
||||
|
||||
|
||||
def valid_videos(videos):
|
||||
# If we have a list of videos, it could be either one video as list of frames or a batch
|
||||
if isinstance(videos, (list, tuple)):
|
||||
for video_or_frame in videos:
|
||||
if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)):
|
||||
return False
|
||||
# If not a list, then we have a single 4D video or 5D batched tensor
|
||||
elif not (is_valid_video(videos) or ((is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_batched_video(videos):
|
||||
if isinstance(videos, (list, tuple)):
|
||||
return is_valid_video(videos[0])
|
||||
elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_scaled_video(video: np.ndarray) -> bool:
|
||||
"""
|
||||
Checks to see whether the pixel values have already been rescaled to [0, 1].
|
||||
"""
|
||||
# It's possible the video has pixel values in [0, 255] but is of floating type
|
||||
return np.min(video) >= 0 and np.max(video) <= 1
|
||||
|
||||
|
||||
def convert_pil_frames_to_video(videos: List[VideoInput]) -> List[Union["np.ndarray", "torch.Tensor"]]:
|
||||
"""
|
||||
Given a batch of videos, converts each video to a 4D array. If video is already in array type,
|
||||
it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element.
|
||||
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video inputs to turn into a list of videos.
|
||||
"""
|
||||
|
||||
if not isinstance(videos[0], (list, tuple)):
|
||||
return videos
|
||||
|
||||
video_converted = []
|
||||
for video in videos:
|
||||
video = [np.array(frame) for frame in video]
|
||||
video = np.stack(video)
|
||||
video_converted.append(video)
|
||||
return video_converted
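
# Illustrative sketch of the conversion above (hypothetical `_example_*` helper, reusing this
# module's `np` and `PIL.Image` imports): a clip passed as a list of PIL frames comes back as
# one 4D array per video.
def _example_convert_pil_frames():
    clip = [[PIL.Image.new("RGB", (64, 48)) for _ in range(4)]]  # one video with 4 frames
    arrays = convert_pil_frames_to_video(clip)
    # PIL sizes are (width, height), so each frame becomes a (48, 64, 3) array
    assert arrays[0].shape == (4, 48, 64, 3)  # (num_frames, height, width, channels)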
|
||||
|
||||
|
||||
def make_batched_videos(videos) -> List[Union["np.ndarray", "torch.Tensor"]]:
|
||||
"""
|
||||
Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
|
||||
If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as a list of `PIL.Image`
|
||||
frames are converted to 4D arrays.
|
||||
|
||||
We assume that all inputs in the list are in the same format, based on the type of the first element.
|
||||
|
||||
Args:
|
||||
videos (`VideoInput`):
|
||||
Video inputs to turn into a list of videos.
|
||||
"""
|
||||
if not valid_videos(videos):
|
||||
raise ValueError(
|
||||
f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
|
||||
f" type {type(videos)}."
|
||||
)
|
||||
|
||||
if is_batched_video(videos):
|
||||
pass
|
||||
elif is_valid_video(videos):
|
||||
videos = [videos]
|
||||
# only one frame passed, thus we unsqueeze time dim
|
||||
elif is_valid_image(videos):
|
||||
videos = [np.array(videos)[None, ...]]
|
||||
# nested batch so we need to unflatten
|
||||
elif isinstance(videos[0], (list, tuple)) and is_valid_video(videos[0][0]):
|
||||
return [video for sublist in videos for video in sublist]
|
||||
return convert_pil_frames_to_video(videos)
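
# Illustrative sketch (hypothetical `_example_*` helper): the accepted input layouts above
# all normalize to a list of 4D arrays, one entry per video.
def _example_make_batched_videos():
    single = np.zeros((4, 48, 64, 3), dtype=np.uint8)            # one 4D video array
    frames = [PIL.Image.new("RGB", (64, 48)) for _ in range(4)]  # one video as PIL frames
    for videos in (single, frames):
        out = make_batched_videos(videos)
        assert isinstance(out, list) and out[0].shape == (4, 48, 64, 3)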
|
||||
|
||||
|
||||
def get_video_size(video: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns the (height, width) dimensions of the video.
|
||||
|
||||
Args:
|
||||
video (`np.ndarray`):
|
||||
The video to get the dimensions of.
|
||||
channel_dim (`ChannelDimension`, *optional*):
|
||||
Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video.
|
||||
|
||||
Returns:
|
||||
A tuple of the video's height and width.
|
||||
"""
|
||||
if channel_dim is None:
|
||||
channel_dim = infer_channel_dimension_format(video)
|
||||
|
||||
if channel_dim == ChannelDimension.FIRST:
|
||||
return video.shape[-2], video.shape[-1]
|
||||
elif channel_dim == ChannelDimension.LAST:
|
||||
return video.shape[-3], video.shape[-2]
|
||||
else:
|
||||
raise ValueError(f"Unsupported data format: {channel_dim}")
|
||||
|
||||
|
||||
def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
|
||||
"""
|
||||
Creates a numpy array for uniform sampling of `num_frames` frames from `total_num_frames`
|
||||
when loading a video.
|
||||
|
||||
Args:
|
||||
total_num_frames (`int`):
|
||||
Total number of frames that a video has.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not specified, all frames are sampled.
|
||||
|
||||
Returns:
|
||||
np.ndarray: np array of frame indices that will be sampled.
|
||||
"""
|
||||
if num_frames is not None:
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
|
||||
else:
|
||||
indices = np.arange(0, total_num_frames).astype(int)
|
||||
return indices
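
# Worked example (sketch): sampling 4 frames from a 10-frame clip steps through the arange
# above with a stride of 10 / 4 = 2.5, truncated to integer indices.
def _example_uniform_indices():
    indices = get_uniform_frame_indices(total_num_frames=10, num_frames=4)
    assert indices.tolist() == [0, 2, 5, 7]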
|
||||
|
||||
|
||||
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
|
||||
"""
|
||||
A default sampling function that replicates the logic used in get_uniform_frame_indices,
|
||||
while optionally handling `fps` if `num_frames` is not provided.
|
||||
|
||||
Args:
|
||||
metadata (`VideoMetadata`):
|
||||
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly.
|
||||
fps (`int`, *optional*):
|
||||
Desired frames per second. Takes priority over num_frames if both are provided.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Array of frame indices to sample.
|
||||
"""
|
||||
total_num_frames = metadata.total_num_frames
|
||||
video_fps = metadata.fps
|
||||
|
||||
# If num_frames is not given but fps is, calculate num_frames from fps
|
||||
if num_frames is None and fps is not None:
|
||||
num_frames = int(total_num_frames / video_fps * fps)
|
||||
if num_frames > total_num_frames:
|
||||
raise ValueError(
|
||||
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
|
||||
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
|
||||
)
|
||||
|
||||
if num_frames is not None:
|
||||
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
|
||||
else:
|
||||
indices = np.arange(0, total_num_frames, dtype=int)
|
||||
return indices
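
# Illustrative sketch of a custom sampler, as referenced in the `sample_indices_fn` docstrings
# below (hypothetical helper): keep only the frames from the first second of the clip.
def _example_first_second_indices(metadata: VideoMetadata, **kwargs):
    end = min(int(metadata.fps), metadata.total_num_frames)
    return np.arange(0, end, dtype=int)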
|
||||
|
||||
|
||||
def read_video_opencv(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the OpenCV backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`):
|
||||
A callable that returns the indices at which the video should be sampled. If the video has to be sampled with
a different technique than the one provided by the `num_frames` or `fps` arguments, provide your own `sample_indices_fn`.
If not provided, simple uniform sampling with `fps` is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import cv2
|
||||
requires_backends(read_video_opencv, ["cv2"])
|
||||
import cv2
|
||||
|
||||
video = cv2.VideoCapture(video_path)
|
||||
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
video_fps = video.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
index = 0
|
||||
frames = []
|
||||
while video.isOpened():
|
||||
success, frame = video.read()
|
||||
if not success:
|
||||
break
|
||||
if index in indices:
|
||||
height, width, channel = frame.shape
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frames.append(frame[0:height, 0:width, 0:channel])
|
||||
if success:
|
||||
index += 1
|
||||
if index >= total_num_frames:
|
||||
break
|
||||
|
||||
video.release()
|
||||
metadata.frames_indices = indices
|
||||
return np.stack(frames), metadata
|
||||
|
||||
|
||||
def read_video_decord(
|
||||
video_path: str,
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode a video using the Decord backend.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable that returns the indices at which the video should be sampled. If the video has to be sampled with
a different technique than the one provided by the `num_frames` or `fps` arguments, provide your own `sample_indices_fn`.
If not provided, simple uniform sampling with `fps` is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import from decord
|
||||
requires_backends(read_video_decord, ["decord"])
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
|
||||
video_fps = vr.get_avg_fps()
|
||||
total_num_frames = len(vr)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
|
||||
)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = vr.get_batch(indices).asnumpy()
|
||||
metadata.frames_indices = indices
|
||||
return frames, metadata
|
||||
|
||||
|
||||
def read_video_pyav(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with PyAV decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable that returns the indices at which the video should be sampled. If the video has to be sampled with
a different technique than the one provided by the `num_frames` or `fps` arguments, provide your own `sample_indices_fn`.
If not provided, simple uniform sampling with `fps` is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
# Lazy import av
|
||||
requires_backends(read_video_pyav, ["av"])
|
||||
import av
|
||||
|
||||
container = av.open(video_path)
|
||||
total_num_frames = container.streams.video[0].frames
|
||||
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
|
||||
)
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
frames = []
|
||||
container.seek(0)
|
||||
end_index = indices[-1]
|
||||
for i, frame in enumerate(container.decode(video=0)):
|
||||
if i > end_index:
|
||||
break
|
||||
if i >= 0 and i in indices:
|
||||
frames.append(frame)
|
||||
|
||||
video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
def read_video_torchvision(
|
||||
video_path: str,
|
||||
sample_indices_fn: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Decode the video with torchvision decoder.
|
||||
|
||||
Args:
|
||||
video_path (`str`):
|
||||
Path to the video file.
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable that returns the indices at which the video should be sampled. If the video has to be sampled with
a different technique than the one provided by the `num_frames` or `fps` arguments, provide your own `sample_indices_fn`.
If not provided, simple uniform sampling with `fps` is performed.
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- `VideoMetadata` object.
|
||||
"""
|
||||
video, _, info = torchvision_io.read_video(
|
||||
video_path,
|
||||
start_pts=0.0,
|
||||
end_pts=None,
|
||||
pts_unit="sec",
|
||||
output_format="THWC",
|
||||
)
|
||||
video_fps = info["video_fps"]
|
||||
total_num_frames = video.size(0)
|
||||
duration = total_num_frames / video_fps if video_fps else 0
|
||||
metadata = VideoMetadata(
|
||||
total_num_frames=int(total_num_frames),
|
||||
fps=float(video_fps),
|
||||
duration=float(duration),
|
||||
video_backend="torchvision",
|
||||
)
|
||||
|
||||
indices = sample_indices_fn(metadata=metadata, **kwargs)
|
||||
|
||||
video = video[indices].contiguous().numpy()
|
||||
metadata.frames_indices = indices
|
||||
return video, metadata
|
||||
|
||||
|
||||
VIDEO_DECODERS = {
|
||||
"decord": read_video_decord,
|
||||
"opencv": read_video_opencv,
|
||||
"pyav": read_video_pyav,
|
||||
"torchvision": read_video_torchvision,
|
||||
}
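
# Illustrative sketch: `load_video` below dispatches through this mapping, so a decoder can
# also be called directly with a sampling function (here the hypothetical sampler sketched above).
def _example_direct_decoder_call(video_path: str):
    decoder = VIDEO_DECODERS["pyav"]
    frames, metadata = decoder(video_path, sample_indices_fn=_example_first_second_indices)
    return frames, metadata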
|
||||
|
||||
|
||||
def load_video(
|
||||
video: Union[str, "VideoInput"],
|
||||
num_frames: Optional[int] = None,
|
||||
fps: Optional[int] = None,
|
||||
backend: str = "pyav",
|
||||
sample_indices_fn: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> np.array:
|
||||
"""
|
||||
Loads `video` to a numpy array.
|
||||
|
||||
Args:
|
||||
video (`str` or `VideoInput`):
|
||||
The video to convert to the numpy array format. Can be a link to a video or a local path.
|
||||
num_frames (`int`, *optional*):
|
||||
Number of frames to sample uniformly. If not passed, the whole video is loaded.
|
||||
fps (`int`, *optional*):
|
||||
Number of frames to sample per second. Should be passed only when `num_frames=None`.
|
||||
If not specified and `num_frames==None`, all frames are sampled.
|
||||
backend (`str`, *optional*, defaults to `"pyav"`):
|
||||
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
|
||||
sample_indices_fn (`Callable`, *optional*):
|
||||
A callable that returns the indices at which the video should be sampled. If the video has to be sampled with
a different technique than the one provided by the `num_frames` or `fps` arguments, provide your own `sample_indices_fn`;
in that case it takes priority over the other sampling arguments. If not provided, simple uniform sampling with `fps` is performed.
The function receives all args and kwargs passed to `load_video` and should return valid
indices at which the video should be sampled. For example:
|
||||
|
||||
Example:
|
||||
def sample_indices_fn(metadata, **kwargs):
|
||||
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
|
||||
|
||||
Returns:
|
||||
Tuple[`np.array`, Dict]: A tuple containing:
|
||||
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
|
||||
- Metadata dictionary.
|
||||
"""
|
||||
|
||||
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
|
||||
if fps is not None and num_frames is not None and sample_indices_fn is None:
|
||||
raise ValueError(
|
||||
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
|
||||
)
|
||||
|
||||
# If user didn't pass a sampling function, create one on the fly with default logic
|
||||
if sample_indices_fn is None:
|
||||
|
||||
def sample_indices_fn_func(metadata, **fn_kwargs):
|
||||
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
|
||||
|
||||
sample_indices_fn = sample_indices_fn_func
|
||||
|
||||
if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
|
||||
if not is_yt_dlp_available():
|
||||
raise ImportError("To load a video from YouTube url you have to install `yt_dlp` first.")
|
||||
# Lazy import from yt_dlp
|
||||
requires_backends(load_video, ["yt_dlp"])
|
||||
from yt_dlp import YoutubeDL
|
||||
|
||||
buffer = BytesIO()
|
||||
with redirect_stdout(buffer), YoutubeDL() as f:
|
||||
f.download([video])
|
||||
bytes_obj = buffer.getvalue()
|
||||
file_obj = BytesIO(bytes_obj)
|
||||
elif video.startswith("http://") or video.startswith("https://"):
|
||||
file_obj = BytesIO(requests.get(video).content)
|
||||
elif os.path.isfile(video):
|
||||
file_obj = video
|
||||
elif is_valid_image(video) or (isinstance(video, (list, tuple)) and is_valid_image(video[0])):
|
||||
file_obj = None
|
||||
else:
|
||||
raise TypeError("Incorrect format used for video. Should be an url linking to an video or a local path.")
|
||||
|
||||
# URLs can also be loaded with decord, but not with cv2/torchvision:
# both fail on url links
|
||||
video_is_url = video.startswith("http://") or video.startswith("https://")
|
||||
if video_is_url and backend in ["opencv", "torchvision"]:
|
||||
raise ValueError(
|
||||
"If you are trying to load a video from URL, you can decode the video only with `pyav` or `decord` as backend"
|
||||
)
|
||||
|
||||
if file_obj is None:
|
||||
return video
|
||||
|
||||
if (
|
||||
(not is_decord_available() and backend == "decord")
|
||||
or (not is_av_available() and backend == "pyav")
|
||||
or (not is_cv2_available() and backend == "opencv")
|
||||
or (not is_torchvision_available() and backend == "torchvision")
|
||||
):
|
||||
raise ImportError(
|
||||
f"You chose backend={backend} for loading the video but the required library is not found in your environment "
|
||||
f"Make sure to install {backend} before loading the video."
|
||||
)
|
||||
|
||||
video_decoder = VIDEO_DECODERS[backend]
|
||||
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
|
||||
return video, metadata
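
# Illustrative usage sketch, assuming a local "clip.mp4" exists: load 8 uniformly sampled
# frames with the default pyav backend, or let a custom sampler decide.
def _example_load_video():
    frames, metadata = load_video("clip.mp4", num_frames=8)
    # frames.shape -> (8, height, width, 3) RGB frames
    # Custom sampling: take every other frame of the whole clip.
    frames, metadata = load_video(
        "clip.mp4",
        sample_indices_fn=lambda metadata, **kwargs: np.arange(0, metadata.total_num_frames, 2, dtype=int),
    )
    return frames, metadata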
|
||||
|
||||
|
||||
def convert_to_rgb(
|
||||
video: np.array,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.array:
|
||||
"""
|
||||
Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.
|
||||
|
||||
Args:
|
||||
video (`np.array`):
|
||||
The video to convert.
|
||||
data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the output video. If unset, will use the inferred format from the input.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input video. If unset, will use the inferred format from the input.
|
||||
"""
|
||||
if not isinstance(video, np.ndarray):
|
||||
raise ValueError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")
|
||||
|
||||
# np.array usually comes with ChannelDimension.LAST so let's convert it
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(video)
|
||||
video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
|
||||
|
||||
# 3 channels for RGB already
|
||||
if video.shape[-3] == 3:
|
||||
return video
|
||||
|
||||
# Grayscale video so we repeat it 3 times for each channel
|
||||
if video.shape[-3] == 1:
|
||||
return video.repeat(3, -3)
|
||||
|
||||
if not (video[..., 3, :, :] < 255).any():
|
||||
return video
|
||||
|
||||
# There is a transparency layer, blend it with a white background.
|
||||
# Calculate the alpha proportion for blending.
|
||||
alpha = video[..., 3, :, :] / 255.0
|
||||
video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
|
||||
return video
|
||||
|
||||
|
||||
def pad(
|
||||
video: np.ndarray,
|
||||
padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
|
||||
mode: PaddingMode = PaddingMode.CONSTANT,
|
||||
constant_values: Union[float, Iterable[float]] = 0.0,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Pads the `video` with the specified (height, width) `padding` and `mode`.
|
||||
|
||||
Args:
|
||||
video (`np.ndarray`):
|
||||
The video to pad.
|
||||
padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
|
||||
Padding to apply to the edges of the height, width axes. Can be one of three formats:
|
||||
- `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
|
||||
- `((before, after),)` yields same before and after pad for height and width.
|
||||
- `(pad,)` or int is a shortcut for before = after = pad width for all axes.
|
||||
mode (`PaddingMode`):
|
||||
The padding mode to use. Can be one of:
|
||||
- `"constant"`: pads with a constant value.
|
||||
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
||||
vector along each axis.
|
||||
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
||||
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
||||
constant_values (`float` or `Iterable[float]`, *optional*):
|
||||
The value to use for the padding if `mode` is `"constant"`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output video. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
|
||||
If unset, will use same as the input video.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input video. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input video.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded video.
|
||||
|
||||
"""
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(video)
|
||||
|
||||
def _expand_for_data_format(values):
|
||||
"""
|
||||
Convert values to be in the format expected by np.pad based on the data format.
|
||||
"""
|
||||
if isinstance(values, (int, float)):
|
||||
values = ((values, values), (values, values))
|
||||
elif isinstance(values, tuple) and len(values) == 1:
|
||||
values = ((values[0], values[0]), (values[0], values[0]))
|
||||
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
|
||||
values = (values, values)
|
||||
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
|
||||
values = values
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {values}")
|
||||
|
||||
# add 0 for channel dimension
|
||||
values = (
|
||||
((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
|
||||
)
|
||||
|
||||
# Add additional padding if there's a batch dimension
|
||||
values = (0, *values) if video.ndim == 5 else values
|
||||
return values
|
||||
|
||||
padding_map = {
|
||||
PaddingMode.CONSTANT: "constant",
|
||||
PaddingMode.REFLECT: "reflect",
|
||||
PaddingMode.REPLICATE: "replicate",
|
||||
PaddingMode.SYMMETRIC: "symmetric",
|
||||
}
|
||||
padding = _expand_for_data_format(padding)
|
||||
|
||||
pad_kwargs = {}
|
||||
if mode not in padding_map:
|
||||
raise ValueError(f"Invalid padding mode: {mode}")
|
||||
elif mode == PaddingMode.CONSTANT:
|
||||
pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)
|
||||
|
||||
video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
|
||||
video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
|
||||
return video
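
# Illustrative sketch: pad a channels-last clip by 2 pixels on every side of the height and
# width axes; the frame and channel axes are left untouched.
def _example_pad_video():
    video = np.zeros((4, 48, 64, 3))
    padded = pad(video, padding=((2, 2), (2, 2)), input_data_format=ChannelDimension.LAST)
    assert padded.shape == (4, 52, 68, 3)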
|
||||
|
||||
|
||||
def group_videos_by_shape(
|
||||
videos: List["torch.Tensor"],
|
||||
) -> Tuple[Dict[Tuple[int, int], List["torch.Tensor"]], Dict[int, Tuple[Tuple[int, int], int]]]:
|
||||
"""
|
||||
Groups videos by shape.
|
||||
Returns a dictionary with the shape as key and a list of videos with that shape as value,
|
||||
and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value.
|
||||
"""
|
||||
grouped_videos = {}
|
||||
grouped_videos_index = {}
|
||||
for i, video in enumerate(videos):
|
||||
shape = video.shape[-2::]
|
||||
if shape not in grouped_videos:
|
||||
grouped_videos[shape] = []
|
||||
grouped_videos[shape].append(video)
|
||||
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
|
||||
# stack videos with the same shape
|
||||
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
|
||||
return grouped_videos, grouped_videos_index
|
||||
|
||||
|
||||
def reorder_videos(
|
||||
processed_videos: Dict[Tuple[int, int], "torch.Tensor"], grouped_videos_index: Dict[int, Tuple[int, int]]
|
||||
) -> List["torch.Tensor"]:
|
||||
"""
|
||||
Reconstructs a list of videos in the original order.
|
||||
"""
|
||||
return [
|
||||
processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]]
|
||||
for i in range(len(grouped_videos_index))
|
||||
]
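
# Illustrative round trip (sketch): group clips of mixed resolutions, transform each stacked
# group, then restore the original ordering.
def _example_group_and_reorder():
    videos = [torch.zeros(4, 3, 32, 32), torch.zeros(4, 3, 64, 64), torch.zeros(4, 3, 32, 32)]
    grouped, index = group_videos_by_shape(videos)
    processed = {shape: stack * 2 for shape, stack in grouped.items()}  # any batched op
    restored = reorder_videos(processed, index)
    assert [v.shape for v in restored] == [v.shape for v in videos]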
|
@ -73,6 +73,19 @@ class AutoImageProcessorTest(unittest.TestCase):
            config = AutoImageProcessor.from_pretrained(tmpdirname)
            self.assertIsInstance(config, CLIPImageProcessor)

    def test_image_processor_from_new_filename(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            processor_tmpfile = Path(tmpdirname) / "preprocessor_config.json"
            config_tmpfile = Path(tmpdirname) / "config.json"
            json.dump(
                {"image_processor_type": "CLIPImageProcessor", "processor_class": "CLIPProcessor"},
                open(processor_tmpfile, "w"),
            )
            json.dump({"model_type": "clip"}, open(config_tmpfile, "w"))

            config = AutoImageProcessor.from_pretrained(tmpdirname)
            self.assertIsInstance(config, CLIPImageProcessor)

    def test_image_processor_from_local_directory_from_config(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            model_config = CLIPConfig()
@ -40,7 +40,11 @@ from transformers import (
)
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
from transformers.utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, is_tokenizers_available
from transformers.utils import (
    FEATURE_EXTRACTOR_NAME,
    PROCESSOR_NAME,
    is_tokenizers_available,
)


sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
@ -395,6 +399,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
        processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-convnext")
        self.assertEqual(processor.__class__.__name__, "ConvNextImageProcessor")

    def test_auto_processor_save_load(self):
        processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
        with tempfile.TemporaryDirectory() as tmp_dir:
            processor.save_pretrained(tmp_dir)
            second_processor = AutoProcessor.from_pretrained(tmp_dir)
            self.assertEqual(second_processor.__class__.__name__, processor.__class__.__name__)


@is_staging_test
class ProcessorPushToHubTester(unittest.TestCase):
252 tests/models/auto/test_video_processing_auto.py Normal file
@ -0,0 +1,252 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
VIDEO_PROCESSOR_MAPPING,
|
||||
AutoConfig,
|
||||
AutoVideoProcessor,
|
||||
LlavaOnevisionConfig,
|
||||
LlavaOnevisionVideoProcessor,
|
||||
)
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_torch
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||
from test_module.custom_video_processing import CustomVideoProcessor # noqa E402
|
||||
|
||||
|
||||
@require_torch
|
||||
class AutoVideoProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
transformers.dynamic_module_utils.TIME_OUT_REMOTE_CODE = 0
|
||||
|
||||
def test_video_processor_from_model_shortcut(self):
|
||||
config = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
|
||||
self.assertIsInstance(config, LlavaOnevisionVideoProcessor)
|
||||
|
||||
def test_video_processor_from_local_directory_from_key(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
processor_tmpfile = Path(tmpdirname) / "video_preprocessor_config.json"
|
||||
config_tmpfile = Path(tmpdirname) / "config.json"
|
||||
json.dump(
|
||||
{
|
||||
"video_processor_type": "LlavaOnevisionVideoProcessor",
|
||||
"processor_class": "LlavaOnevisionProcessor",
|
||||
},
|
||||
open(processor_tmpfile, "w"),
|
||||
)
|
||||
json.dump({"model_type": "llava_onevision"}, open(config_tmpfile, "w"))
|
||||
|
||||
config = AutoVideoProcessor.from_pretrained(tmpdirname)
|
||||
self.assertIsInstance(config, LlavaOnevisionVideoProcessor)
|
||||
|
||||
def test_video_processor_from_local_directory_from_preprocessor_key(self):
|
||||
# Ensure we can load the video processor from the preprocessor config (preprocessor_config.json)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
processor_tmpfile = Path(tmpdirname) / "preprocessor_config.json"
|
||||
config_tmpfile = Path(tmpdirname) / "config.json"
|
||||
json.dump(
|
||||
{
|
||||
"video_processor_type": "LlavaOnevisionVideoProcessor",
|
||||
"processor_class": "LlavaOnevisionProcessor",
|
||||
},
|
||||
open(processor_tmpfile, "w"),
|
||||
)
|
||||
json.dump({"model_type": "llava_onevision"}, open(config_tmpfile, "w"))
|
||||
|
||||
config = AutoVideoProcessor.from_pretrained(tmpdirname)
|
||||
self.assertIsInstance(config, LlavaOnevisionVideoProcessor)
|
||||
|
||||
def test_video_processor_from_local_directory_from_config(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model_config = LlavaOnevisionConfig()
|
||||
|
||||
# Create a dummy config file with video_processor_type
|
||||
processor_tmpfile = Path(tmpdirname) / "video_preprocessor_config.json"
|
||||
config_tmpfile = Path(tmpdirname) / "config.json"
|
||||
json.dump(
|
||||
{
|
||||
"video_processor_type": "LlavaOnevisionVideoProcessor",
|
||||
"processor_class": "LlavaOnevisionProcessor",
|
||||
},
|
||||
open(processor_tmpfile, "w"),
|
||||
)
|
||||
json.dump({"model_type": "llava_onevision"}, open(config_tmpfile, "w"))
|
||||
|
||||
# remove video_processor_type to make sure config.json alone is enough to load the video processor locally
|
||||
config_dict = AutoVideoProcessor.from_pretrained(tmpdirname).to_dict()
|
||||
|
||||
config_dict.pop("video_processor_type")
|
||||
config = LlavaOnevisionVideoProcessor(**config_dict)
|
||||
|
||||
# save in new folder
|
||||
model_config.save_pretrained(tmpdirname)
|
||||
config.save_pretrained(tmpdirname)
|
||||
|
||||
config = AutoVideoProcessor.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure private variable is not incorrectly saved
|
||||
dict_as_saved = json.loads(config.to_json_string())
|
||||
self.assertTrue("_processor_class" not in dict_as_saved)
|
||||
|
||||
self.assertIsInstance(config, LlavaOnevisionVideoProcessor)
|
||||
|
||||
def test_video_processor_from_local_file(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
processor_tmpfile = Path(tmpdirname) / "video_preprocessor_config.json"
|
||||
json.dump(
|
||||
{
|
||||
"video_processor_type": "LlavaOnevisionVideoProcessor",
|
||||
"processor_class": "LlavaOnevisionProcessor",
|
||||
},
|
||||
open(processor_tmpfile, "w"),
|
||||
)
|
||||
|
||||
config = AutoVideoProcessor.from_pretrained(processor_tmpfile)
|
||||
self.assertIsInstance(config, LlavaOnevisionVideoProcessor)
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"llava-hf/llava-doesnt-exist is not a local folder and is not a valid model identifier",
|
||||
):
|
||||
_ = AutoVideoProcessor.from_pretrained("llava-hf/llava-doesnt-exist")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoVideoProcessor.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_video_processor_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
|
||||
):
|
||||
_ = AutoVideoProcessor.from_pretrained("hf-internal-testing/config-no-model")
|
||||
|
||||
def test_from_pretrained_dynamic_video_processor(self):
|
||||
# If remote code is not set, we will time out when asking whether to load the model.
|
||||
with self.assertRaises(ValueError):
|
||||
video_processor = AutoVideoProcessor.from_pretrained("hf-internal-testing/test_dynamic_video_processor")
|
||||
# If remote code is disabled, we can't load this config.
|
||||
with self.assertRaises(ValueError):
|
||||
video_processor = AutoVideoProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=False
|
||||
)
|
||||
|
||||
video_processor = AutoVideoProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True
|
||||
)
|
||||
self.assertEqual(video_processor.__class__.__name__, "NewVideoProcessor")
|
||||
|
||||
# Test the dynamic module is loaded only once.
|
||||
reloaded_video_processor = AutoVideoProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True
|
||||
)
|
||||
self.assertIs(video_processor.__class__, reloaded_video_processor.__class__)
|
||||
|
||||
# Test the video processor can be reloaded.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
video_processor.save_pretrained(tmp_dir)
|
||||
reloaded_video_processor = AutoVideoProcessor.from_pretrained(tmp_dir, trust_remote_code=True)
|
||||
self.assertEqual(reloaded_video_processor.__class__.__name__, "NewVideoProcessor")
|
||||
|
||||
# The video processor file is cached in the snapshot directory. So the module file is not changed after dumping
|
||||
# to a temp dir. Because the revision of the module file is not changed.
|
||||
# Test the dynamic module is loaded only once if the module file is not changed.
|
||||
self.assertIs(video_processor.__class__, reloaded_video_processor.__class__)
|
||||
|
||||
# Test the dynamic module is reloaded if we force it.
|
||||
reloaded_video_processor = AutoVideoProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True, force_download=True
|
||||
)
|
||||
self.assertIsNot(video_processor.__class__, reloaded_video_processor.__class__)
|
||||
|
||||
def test_new_video_processor_registration(self):
|
||||
try:
|
||||
AutoConfig.register("custom", CustomConfig)
|
||||
AutoVideoProcessor.register(CustomConfig, CustomVideoProcessor)
|
||||
# Trying to register something existing in the Transformers library will raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
AutoVideoProcessor.register(LlavaOnevisionConfig, LlavaOnevisionVideoProcessor)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
processor_tmpfile = Path(tmpdirname) / "video_preprocessor_config.json"
|
||||
config_tmpfile = Path(tmpdirname) / "config.json"
|
||||
json.dump(
|
||||
{
|
||||
"video_processor_type": "LlavaOnevisionVideoProcessor",
|
||||
"processor_class": "LlavaOnevisionProcessor",
|
||||
},
|
||||
open(processor_tmpfile, "w"),
|
||||
)
|
||||
json.dump({"model_type": "llava_onevision"}, open(config_tmpfile, "w"))
|
||||
|
||||
video_processor = CustomVideoProcessor.from_pretrained(tmpdirname)
|
||||
|
||||
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
video_processor.save_pretrained(tmp_dir)
|
||||
new_video_processor = AutoVideoProcessor.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(new_video_processor, CustomVideoProcessor)
|
||||
|
||||
finally:
|
||||
if "custom" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["custom"]
|
||||
if CustomConfig in VIDEO_PROCESSOR_MAPPING._extra_content:
|
||||
del VIDEO_PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
def test_from_pretrained_dynamic_video_processor_conflict(self):
|
||||
class NewVideoProcessor(LlavaOnevisionVideoProcessor):
|
||||
is_local = True
|
||||
|
||||
try:
|
||||
AutoConfig.register("custom", CustomConfig)
|
||||
AutoVideoProcessor.register(CustomConfig, NewVideoProcessor)
|
||||
# If remote code is not set, the default is to use local
|
||||
video_processor = AutoVideoProcessor.from_pretrained("hf-internal-testing/test_dynamic_video_processor")
|
||||
self.assertEqual(video_processor.__class__.__name__, "NewVideoProcessor")
|
||||
self.assertTrue(video_processor.is_local)
|
||||
|
||||
# If remote code is disabled, we load the local one.
|
||||
video_processor = AutoVideoProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=False
|
||||
)
|
||||
self.assertEqual(video_processor.__class__.__name__, "NewVideoProcessor")
|
||||
self.assertTrue(video_processor.is_local)
|
||||
|
||||
# If remote is enabled, we load from the Hub
|
||||
video_processor = AutoVideoProcessor.from_pretrained(
|
||||
"hf-internal-testing/test_dynamic_video_processor", trust_remote_code=True
|
||||
)
|
||||
self.assertEqual(video_processor.__class__.__name__, "NewVideoProcessor")
|
||||
self.assertTrue(not hasattr(video_processor, "is_local"))
|
||||
|
||||
finally:
|
||||
if "custom" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["custom"]
|
||||
if CustomConfig in VIDEO_PROCESSOR_MAPPING._extra_content:
|
||||
del VIDEO_PROCESSOR_MAPPING._extra_content[CustomConfig]
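
For orientation, a condensed sketch of the user-facing registration flow these tests exercise; the MyConfig/MyVideoProcessor classes and the BaseVideoProcessor import path are assumptions for illustration, not code from this commit:

from transformers import AutoConfig, AutoVideoProcessor, PretrainedConfig
from transformers.video_processing_utils import BaseVideoProcessor  # assumed location of the new base class

class MyConfig(PretrainedConfig):
    model_type = "my-video-model"

class MyVideoProcessor(BaseVideoProcessor):
    pass

# once registered, the Auto API resolves the custom pair like any built-in one
AutoConfig.register("my-video-model", MyConfig)
AutoVideoProcessor.register(MyConfig, MyVideoProcessor)
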
@ -1,190 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import InstructBlipVideoImageProcessor
|
||||
|
||||
|
||||
class InstructBlipVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_channels=3,
|
||||
image_size=24,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
frames=4,
|
||||
):
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.frames = frames
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.frames, self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(self.frames, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(self.frames, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * self.frames)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class InstructBlipVideoProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = InstructBlipVideoImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = InstructBlipVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input (pass as `videos` arg to test that ImageProcessor can handle videos in absence of images!)
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (1, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (5, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
# Test not batched input (pass as `videos` arg to test that ImageProcessor can handle videos in absence of images!)
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (1, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (5, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (1, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (5, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
@ -17,8 +17,8 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
@ -28,14 +28,16 @@ if is_vision_available():
|
||||
AutoProcessor,
|
||||
BertTokenizerFast,
|
||||
GPT2Tokenizer,
|
||||
InstructBlipVideoImageProcessor,
|
||||
InstructBlipVideoProcessor,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import InstructBlipVideoVideoProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
# Copied from tests.models.instructblip.test_processor_instructblip.InstructBlipProcessorTest with InstructBlip->InstructBlipVideo, BlipImageProcessor->InstructBlipVideoImageProcessor
|
||||
@require_torch
|
||||
class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = InstructBlipVideoProcessor
|
||||
|
||||
@ -43,23 +45,23 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def setUpClass(cls):
|
||||
cls.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
image_processor = InstructBlipVideoImageProcessor()
|
||||
video_processor = InstructBlipVideoVideoProcessor()
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
|
||||
qformer_tokenizer = BertTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
processor = InstructBlipVideoProcessor(image_processor, tokenizer, qformer_tokenizer)
|
||||
processor = InstructBlipVideoProcessor(video_processor, tokenizer, qformer_tokenizer)
|
||||
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_qformer_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer
|
||||
|
||||
def get_video_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
|
||||
@ -67,14 +69,14 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = InstructBlipVideoProcessor(
|
||||
tokenizer=self.get_tokenizer(),
|
||||
image_processor=self.get_image_processor(),
|
||||
video_processor=self.get_video_processor(),
|
||||
qformer_tokenizer=self.get_qformer_tokenizer(),
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
processor.save_pretrained(tmpdir)
|
||||
|
||||
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
|
||||
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
|
||||
video_processor_add_kwargs = self.get_video_processor(do_normalize=False, padding_value=1.0)
|
||||
|
||||
processor = InstructBlipVideoProcessor.from_pretrained(
|
||||
tmpdir, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
|
||||
@ -83,34 +85,34 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, PreTrainedTokenizerFast)
|
||||
|
||||
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.image_processor, InstructBlipVideoImageProcessor)
|
||||
self.assertEqual(processor.video_processor.to_json_string(), video_processor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.video_processor, InstructBlipVideoVideoProcessor)
|
||||
self.assertIsInstance(processor.qformer_tokenizer, BertTokenizerFast)
|
||||
|
||||
def test_image_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
def test_video_processor(self):
|
||||
video_processor = self.get_video_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||
|
||||
processor = InstructBlipVideoProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
||||
)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
input_feat_extract = image_processor(image_input, return_tensors="np")
|
||||
input_processor = processor(images=image_input, return_tensors="np")
|
||||
input_feat_extract = video_processor(image_input, return_tensors="pt")
|
||||
input_processor = processor(images=image_input, return_tensors="pt")
|
||||
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_tokenizer(self):
|
||||
image_processor = self.get_image_processor()
|
||||
video_processor = self.get_video_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||
|
||||
processor = InstructBlipVideoProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
||||
)
|
||||
|
||||
input_str = ["lower newer"]
|
||||
@ -127,12 +129,12 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(encoded_tokens_qformer[key], encoded_processor["qformer_" + key])
|
||||
|
||||
def test_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
video_processor = self.get_video_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||
|
||||
processor = InstructBlipVideoProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
@ -150,12 +152,12 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor()
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
image_processor = self.get_image_processor()
|
||||
video_processor = self.get_video_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||
|
||||
processor = InstructBlipVideoProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
||||
)
|
||||
|
||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||
@ -166,12 +168,12 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(decoded_tok, decoded_processor)
|
||||
|
||||
def test_model_input_names(self):
|
||||
image_processor = self.get_image_processor()
|
||||
video_processor = self.get_video_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
qformer_tokenizer = self.get_qformer_tokenizer()
|
||||
|
||||
processor = InstructBlipVideoProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
|
||||
tokenizer=tokenizer, video_processor=video_processor, qformer_tokenizer=qformer_tokenizer
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
|
@ -0,0 +1,116 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
if is_torchvision_available():
|
||||
from transformers import InstructBlipVideoVideoProcessor
|
||||
|
||||
|
||||
class InstructBlipVideoVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_channels=3,
|
||||
num_frames=4,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
super().__init__()
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_video_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_video_shape(self, images):
|
||||
return self.num_frames, self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
|
||||
videos = prepare_video_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_frames=self.num_frames,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class InstructBlipVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
||||
fast_video_processing_class = InstructBlipVideoVideoProcessor if is_torchvision_available() else None
|
||||
input_name = "pixel_values"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.video_processor_tester = InstructBlipVideoVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
def video_processor_dict(self):
|
||||
return self.video_processor_tester.prepare_video_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
video_processing = self.fast_video_processing_class(**self.video_processor_dict)
|
||||
self.assertTrue(hasattr(video_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(video_processing, "size"))
|
||||
self.assertTrue(hasattr(video_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(video_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(video_processing, "image_std"))
|
||||
self.assertTrue(hasattr(video_processing, "do_convert_rgb"))
|
||||
|
||||
def test_video_processor_from_dict_with_kwargs(self):
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
|
||||
self.assertEqual(video_processor.size, {"height": 18, "width": 18})
|
||||
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42)
|
||||
self.assertEqual(video_processor.size, {"height": 42, "width": 42})
|
@ -18,12 +18,13 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer, InternVLProcessor
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
from ...test_processing_common import MODALITY_INPUT_DATA, ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -31,7 +32,7 @@ if is_torch_available():
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import GotOcr2ImageProcessor
|
||||
from transformers import GotOcr2ImageProcessor, InternVLVideoProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
@ -55,12 +56,22 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
image_std=[0.229, 0.224, 0.225],
|
||||
do_convert_rgb=True,
|
||||
)
|
||||
video_processor = InternVLVideoProcessor(
|
||||
do_resize=True,
|
||||
size={"height": 20, "width": 20},
|
||||
do_rescale=True,
|
||||
rescale_factor=1 / 255,
|
||||
do_normalize=True,
|
||||
image_mean=[0.485, 0.456, 0.406],
|
||||
image_std=[0.229, 0.224, 0.225],
|
||||
do_convert_rgb=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL3-1B-hf", padding_side="left")
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
processor = InternVLProcessor.from_pretrained(
|
||||
"OpenGVLab/InternVL3-1B-hf",
|
||||
processor = InternVLProcessor(
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
**processor_kwargs,
|
||||
)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
@ -69,7 +80,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
return {"image_seq_length": 10}
|
||||
return {"image_seq_length": 2}
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
@ -77,6 +88,9 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_video_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
def get_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
@ -168,6 +182,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
# Override video chat_template tests as InternVLProcessor returns flattened video features
|
||||
@require_av
|
||||
@require_torch
|
||||
def test_apply_chat_template_video_special_processing(self):
|
||||
"""
|
||||
Tests that models can use their own preprocessing to preprocess conversations.
|
||||
@ -225,7 +240,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
num_frames=8,
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
@ -236,6 +251,8 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
# Difference with common tests, InternVLProcessor returns flattened video features, and uses 8 frames by default
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8)
|
||||
|
||||
@require_torch
|
||||
@require_av
|
||||
def test_apply_chat_template_video_frame_sampling(self):
|
||||
processor = self.get_processor()
|
||||
|
||||
@ -271,7 +288,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), num_frames)
|
||||
@ -284,6 +301,7 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 300)
|
||||
@ -302,6 +320,97 @@ class InternVLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 2)
|
||||
|
||||
@require_av
|
||||
@parameterized.expand([(1, "pt"), (2, "pt")])
|
||||
def test_apply_chat_template_video(self, batch_size: int, return_tensors: str):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "video_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"`video_processor` attribute not present in {self.processor_class}")
|
||||
|
||||
batch_messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Describe this."}],
|
||||
},
|
||||
]
|
||||
] * batch_size
|
||||
|
||||
# Test that jinja can be applied
|
||||
formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), batch_size)
|
||||
|
||||
# Test that tokenizing with template and directly with `self.tokenizer` gives same output
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
batch_messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
|
||||
)
|
||||
add_special_tokens = True
|
||||
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
|
||||
add_special_tokens = False
|
||||
tok_output = processor.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=add_special_tokens)
|
||||
expected_output = tok_output.input_ids
|
||||
self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
|
||||
|
||||
# Test that kwargs passed to processor's `__call__` are actually used
|
||||
tokenized_prompt_100 = processor.apply_chat_template(
|
||||
batch_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
max_length=100,
|
||||
)
|
||||
self.assertEqual(len(tokenized_prompt_100[0]), 100)
|
||||
|
||||
# Test that `return_dict=True` returns text related inputs in the dict
|
||||
out_dict_text = processor.apply_chat_template(
|
||||
batch_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
|
||||
self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
|
||||
|
||||
# Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
|
||||
for idx, url in enumerate(MODALITY_INPUT_DATA["videos"][:batch_size]):
|
||||
batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": "video", "url": url}]
|
||||
|
||||
out_dict = processor.apply_chat_template(
|
||||
batch_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
num_frames=4, # by default no more than 4 frames, otherwise too slow
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict)
|
||||
self.assertEqual(len(out_dict["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
|
||||
|
||||
video_len = 4 if batch_size == 1 else 3 # InternVL patches out and removes frames after processing
|
||||
self.assertEqual(len(out_dict[self.videos_input_name]), video_len)
|
||||
for k in out_dict:
|
||||
self.assertIsInstance(out_dict[k], torch.Tensor)
|
||||
|
||||
# Test continue from final message
|
||||
assistant_message = {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is the sound of"}],
|
||||
}
|
||||
for batch_idx in range(batch_size):
|
||||
batch_messages[batch_idx] = batch_messages[batch_idx] + [assistant_message]
|
||||
continue_prompt = processor.apply_chat_template(batch_messages, continue_final_message=True, tokenize=False)
|
||||
for prompt in continue_prompt:
|
||||
self.assertTrue(prompt.endswith("It is the sound of")) # no `eos` token at the end
|
||||
107 tests/models/internvl/test_video_processor_internvl.py Normal file
@ -0,0 +1,107 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
if is_vision_available():
|
||||
if is_torchvision_available():
|
||||
from transformers import InternVLVideoProcessor
|
||||
|
||||
|
||||
class InternVLVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_frames=8,
|
||||
num_channels=3,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"height": 384, "width": 384}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_video_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_video_shape(self, videos):
|
||||
return [self.num_frames, self.num_channels, self.size["height"], self.size["width"]]
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
|
||||
videos = prepare_video_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_frames=self.num_frames,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class InternVLVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
||||
fast_video_processing_class = InternVLVideoProcessor if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.video_processor_tester = InternVLVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
def video_processor_dict(self):
|
||||
return self.video_processor_tester.prepare_video_processor_dict()
|
||||
|
||||
def test_video_processor_from_dict_with_kwargs(self):
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
|
||||
self.assertEqual(video_processor.size, {"height": 384, "width": 384})
|
||||
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42)
|
||||
self.assertEqual(video_processor.size, {"height": 42, "width": 42})
|
@ -1,218 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import LlavaNextVideoImageProcessor
|
||||
|
||||
|
||||
class LlavaNextVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_center_crop=True,
|
||||
crop_size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"shortest_edge": 20}
|
||||
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_center_crop": self.do_center_crop,
|
||||
"crop_size": self.crop_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.crop_size["height"], self.crop_size["width"]
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(8, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * 8)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class LlavaNextVideoProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = LlavaNextVideoImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = LlavaNextVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 20})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input (pass as `videos` arg to test that ImageProcessor can handle videos in absence of images!)
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
# Test not batched input (pass as `videos` arg to test that ImageProcessor can handle videos in absence of images!)
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, torchify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
@unittest.skip("LlavaNextVideoImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
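# --- Illustrative sketch (not part of the diff): the legacy image-processor path for videos ---
# The tests above feed a list of PIL frames through the `images` argument of
# LlavaNextVideoImageProcessor and read the result from `pixel_values_videos`. Whether this
# legacy path remains supported alongside the new dedicated video processor class is an
# assumption here; the snippet only mirrors the call pattern exercised by those tests.
import numpy as np
from PIL import Image

from transformers import LlavaNextVideoImageProcessor

image_processor = LlavaNextVideoImageProcessor(
    size={"shortest_edge": 20},
    crop_size={"height": 18, "width": 18},
)
# A "video" here is simply a list of frames; 8 random RGB frames stand in for a real clip.
frames = [Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)) for _ in range(8)]
outputs = image_processor(images=frames, return_tensors="pt")
print(outputs.pixel_values_videos.shape)  # roughly (1, 8, 3, 18, 18), matching the expected shapes above
# --- end sketch ---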
@ -19,13 +19,16 @@ import unittest
|
||||
|
||||
from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import LlavaNextImageProcessor, LlavaNextVideoImageProcessor
|
||||
from transformers import LlavaNextImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import LlavaNextVideoVideoProcessor
|
||||
|
||||
if is_torch_available:
|
||||
pass
|
||||
@ -39,7 +42,7 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def setUpClass(cls):
|
||||
cls.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = LlavaNextImageProcessor()
|
||||
video_processor = LlavaNextVideoImageProcessor()
|
||||
video_processor = LlavaNextVideoVideoProcessor()
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>", "<video>"]})
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
|
@ -0,0 +1,127 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
if is_vision_available():
|
||||
if is_torchvision_available():
|
||||
from transformers import LlavaNextVideoVideoProcessor
|
||||
|
||||
|
||||
class LlavaNextVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_frames=8,
|
||||
num_channels=3,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_center_crop=True,
|
||||
crop_size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"height": 20, "width": 20}
|
||||
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_video_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_center_crop": self.do_center_crop,
|
||||
"crop_size": self.crop_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_video_shape(self, images):
|
||||
return self.num_frames, self.num_channels, self.crop_size["height"], self.crop_size["width"]
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
|
||||
videos = prepare_video_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_frames=self.num_frames,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class LlavaNextVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
||||
fast_video_processing_class = LlavaNextVideoVideoProcessor if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.video_processor_tester = LlavaNextVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
def video_processor_dict(self):
|
||||
return self.video_processor_tester.prepare_video_processor_dict()
|
||||
|
||||
def test_video_processor_properties(self):
|
||||
video_processing = self.fast_video_processing_class(**self.video_processor_dict)
|
||||
self.assertTrue(hasattr(video_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(video_processing, "size"))
|
||||
self.assertTrue(hasattr(video_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(video_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(video_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(video_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(video_processing, "image_std"))
|
||||
self.assertTrue(hasattr(video_processing, "do_convert_rgb"))
|
||||
|
||||
def test_video_processor_from_dict_with_kwargs(self):
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
|
||||
self.assertEqual(video_processor.size, {"height": 20, "width": 20})
|
||||
self.assertEqual(video_processor.crop_size, {"height": 18, "width": 18})
|
||||
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(video_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(video_processor.crop_size, {"height": 84, "width": 84})
@ -32,7 +32,7 @@ if is_vision_available():
|
||||
from transformers import LlavaOnevisionImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import LlavaOnevisionImageProcessorFast, LlavaOnevisionVideoProcessor
|
||||
from transformers import LlavaOnevisionImageProcessorFast
|
||||
|
||||
|
||||
class LlavaOnevisionImageProcessingTester:
|
||||
@ -91,41 +91,12 @@ class LlavaOnevisionImageProcessingTester:
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# Copied from tests.models.llava_next_video.test_image_processing_llava_next_video.LlavaNextVideoProcessingTester.prepare_video_inputs
|
||||
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(8, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * 8)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = LlavaOnevisionImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = LlavaOnevisionImageProcessorFast if is_torchvision_available() else None
|
||||
video_processing_class = LlavaOnevisionVideoProcessor if is_vision_available() else None
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->LlavaOnevision
|
||||
def setUp(self):
|
||||
@ -148,15 +119,6 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
self.assertTrue(hasattr(image_processing, "image_grid_pinpoints"))
|
||||
|
||||
def test_video_processor_properties(self):
|
||||
image_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
@ -248,58 +210,6 @@ class LlavaOnevisionImageProcessingTest(ImageProcessingTestMixin, unittest.TestC
|
||||
# Image processor should return same pixel values, independently of input format
|
||||
self.assertTrue((encoded_images_nested == encoded_images).all())
|
||||
|
||||
def test_call_pil_video(self):
|
||||
# Initialize image_processing
|
||||
video_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (7, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_numpy_video(self):
|
||||
# Initialize image_processing
|
||||
video_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (7, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch_video(self):
|
||||
# Initialize image_processing
|
||||
video_processing = self.video_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, torchify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (7, 8, 3, 20, 20)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
@unittest.skip(
|
||||
reason="LlavaOnevisionImageProcessorFast doesn't compile (infinitely) when using class transforms"
|
||||
) # FIXME yoni
|
||||
|
@ -16,8 +16,8 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
@ -27,15 +27,18 @@ if is_vision_available():
|
||||
AutoProcessor,
|
||||
LlavaOnevisionImageProcessor,
|
||||
LlavaOnevisionProcessor,
|
||||
LlavaOnevisionVideoProcessor,
|
||||
Qwen2TokenizerFast,
|
||||
)
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import LlavaOnevisionVideoProcessor
|
||||
|
||||
if is_torch_available:
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = LlavaOnevisionProcessor
|
||||
|
||||
|
@ -0,0 +1,116 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
if is_vision_available():
|
||||
if is_torchvision_available():
|
||||
from transformers import LlavaOnevisionVideoProcessor
|
||||
|
||||
|
||||
class LlavaOnevisionVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_frames=8,
|
||||
num_channels=3,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"height": 20, "width": 20}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_video_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_video_shape(self, video):
|
||||
return self.num_frames, self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
|
||||
videos = prepare_video_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_frames=self.num_frames,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class LlavaOnevisionVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
||||
fast_video_processing_class = LlavaOnevisionVideoProcessor if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.video_processor_tester = LlavaOnevisionVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
def video_processor_dict(self):
|
||||
return self.video_processor_tester.prepare_video_processor_dict()
|
||||
|
||||
def test_video_processor_properties(self):
|
||||
video_processing = self.fast_video_processing_class(**self.video_processor_dict)
|
||||
self.assertTrue(hasattr(video_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(video_processing, "size"))
|
||||
self.assertTrue(hasattr(video_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(video_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(video_processing, "image_std"))
|
||||
self.assertTrue(hasattr(video_processing, "do_convert_rgb"))
|
||||
|
||||
def test_video_processor_from_dict_with_kwargs(self):
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
|
||||
self.assertEqual(video_processor.size, {"height": 20, "width": 20})
|
||||
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42)
|
||||
self.assertEqual(video_processor.size, {"shortest_edge": 42})
@ -16,7 +16,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
from transformers import PixtralProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
@ -30,7 +30,7 @@ if is_torch_available():
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
pass
|
||||
|
||||
|
||||
@require_vision
|
||||
@ -42,11 +42,10 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
|
||||
cls.image_0 = np.random.randint(255, size=(3, 876, 1300), dtype=np.uint8)
|
||||
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
|
||||
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
|
||||
cls.image_1 = np.random.randint(255, size=(3, 480, 640), dtype=np.uint8)
|
||||
cls.image_2 = np.random.randint(255, size=(3, 1024, 1024), dtype=np.uint8)
|
||||
|
||||
cls.tmpdirname = tempfile.mkdtemp()
|
||||
cls.addClassCleanup(lambda tempdir=cls.tmpdirname: shutil.rmtree(tempdir))
|
||||
|
@ -15,7 +15,7 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_vision
|
||||
@ -25,8 +25,6 @@ from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import PixtralProcessor
|
||||
|
||||
|
||||
@ -37,11 +35,10 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.url_0 = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
cls.image_0 = Image.open(requests.get(cls.url_0, stream=True).raw)
|
||||
cls.image_0 = np.random.randint(255, size=(3, 876, 1300), dtype=np.uint8)
|
||||
cls.url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
cls.image_1 = Image.open(requests.get(cls.url_1, stream=True).raw)
|
||||
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
|
||||
cls.image_1 = np.random.randint(255, size=(3, 480, 640), dtype=np.uint8)
|
||||
cls.image_2 = np.random.randint(255, size=(3, 1024, 1024), dtype=np.uint8)
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
@ -64,8 +64,12 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
@ -91,8 +95,12 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
@ -125,8 +133,12 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
video_processor = self.get_component("video_processor")
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
@ -159,7 +171,13 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor)
|
||||
video_processor = self.get_component("video_processor")
|
||||
_ = self.processor_class(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
) # Why delete test? TODO: raushan double check tests after cleaning model
|
||||
|
||||
@require_torch
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
|
||||
@ -175,7 +193,13 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor)
|
||||
video_processor = self.get_component("video_processor")
|
||||
_ = self.processor_class(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -190,6 +214,9 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_video_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
def get_feature_extractor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).feature_extractor
|
||||
|
||||
@ -212,10 +239,14 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
video_processor = self.get_video_processor()
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = Qwen2_5OmniProcessor.from_pretrained(self.tmpdirname, use_fast=False)
|
||||
|
||||
@ -230,9 +261,12 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
video_processor = self.get_video_processor()
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
@ -247,9 +281,12 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
video_processor = self.get_video_processor()
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
@ -281,9 +318,12 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = Qwen2_5OmniProcessor(
|
||||
image_processor=image_processor, feature_extractor=feature_extractor, tokenizer=tokenizer
|
||||
video_processor = self.get_video_processor()
|
||||
processor = self.processor_class(
|
||||
tokenizer=tokenizer,
|
||||
video_processor=video_processor,
|
||||
feature_extractor=feature_extractor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
@ -377,7 +417,10 @@ class Qwen2_5OmniProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(input_name in out_dict)
|
||||
self.assertEqual(len(out_dict["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
|
||||
self.assertEqual(len(out_dict[input_name]), batch_size * 1564)
|
||||
|
||||
video_len = 5760 if batch_size == 1 else 5808 # qwen pixels don't scale with bs same way as other models
|
||||
mm_len = batch_size * 1564 if modality == "image" else video_len
|
||||
self.assertEqual(len(out_dict[input_name]), mm_len)
|
||||
|
||||
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
|
||||
for k in out_dict:
|
||||
|
@ -55,6 +55,9 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_video_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
return {
|
||||
@ -68,8 +71,11 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_save_load_pretrained_default(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
image_processor = self.get_image_processor()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2_5_VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = Qwen2_5_VLProcessor.from_pretrained(self.tmpdirname, use_fast=False)
|
||||
|
||||
@ -81,8 +87,11 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_image_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2_5_VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
@ -95,8 +104,11 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2_5_VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
@ -118,8 +130,11 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_model_input_names(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2_5_VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2_5_VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
@ -130,6 +145,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
@require_torch
|
||||
@require_av
|
||||
def _test_apply_chat_template(
|
||||
self,
|
||||
modality: str,
|
||||
@ -212,7 +228,10 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(input_name in out_dict)
|
||||
self.assertEqual(len(out_dict["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
|
||||
self.assertEqual(len(out_dict[input_name]), batch_size * 192)
|
||||
|
||||
video_len = 360 if batch_size == 1 else 320 # qwen pixels don't scale with bs same way as other models
|
||||
mm_len = batch_size * 192 if modality == "image" else video_len
|
||||
self.assertEqual(len(out_dict[input_name]), mm_len)
|
||||
|
||||
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
|
||||
for k in out_dict:
|
||||
@ -394,7 +413,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
|
@ -21,7 +21,7 @@ import requests
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs, prepare_video_inputs
|
||||
|
||||
@ -34,8 +34,8 @@ if is_vision_available():
|
||||
|
||||
from transformers import Qwen2VLImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import Qwen2VLImageProcessorFast
|
||||
# if is_torchvision_available():
|
||||
# from transformers import Qwen2VLImageProcessorFast
|
||||
|
||||
|
||||
class Qwen2VLImageProcessingTester:
|
||||
@ -118,7 +118,7 @@ class Qwen2VLImageProcessingTester:
|
||||
@require_vision
|
||||
class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = Qwen2VLImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = Qwen2VLImageProcessorFast if is_torchvision_available() else None
|
||||
# fast_image_processing_class = Qwen2VLImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
@ -23,7 +23,7 @@ from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import AutoProcessor, Qwen2Tokenizer
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
@ -31,6 +31,9 @@ from ...test_processing_common import ProcessorTesterMixin
|
||||
if is_vision_available():
|
||||
from transformers import Qwen2VLImageProcessor, Qwen2VLProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import Qwen2VLVideoProcessor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@ -55,6 +58,9 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def get_image_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||
|
||||
def get_video_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
|
||||
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
return {"chat_template": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"} # fmt: skip
@ -66,8 +72,11 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_save_load_pretrained_default(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
image_processor = self.get_image_processor()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = Qwen2VLProcessor.from_pretrained(self.tmpdirname, use_fast=False)
|
||||
|
||||
@ -75,12 +84,16 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
|
||||
self.assertIsInstance(processor.tokenizer, Qwen2Tokenizer)
|
||||
self.assertIsInstance(processor.image_processor, Qwen2VLImageProcessor)
|
||||
self.assertIsInstance(processor.video_processor, Qwen2VLVideoProcessor)
|
||||
|
||||
def test_image_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
@ -93,8 +106,11 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_processor(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
@ -113,8 +129,11 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
def test_model_input_names(self):
|
||||
image_processor = self.get_image_processor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
video_processor = self.get_video_processor()
|
||||
|
||||
processor = Qwen2VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = Qwen2VLProcessor(
|
||||
tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor
|
||||
)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
@ -125,6 +144,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
@require_torch
|
||||
@require_av
|
||||
def _test_apply_chat_template(
|
||||
self,
|
||||
modality: str,
|
||||
@ -207,7 +227,10 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(input_name in out_dict)
|
||||
self.assertEqual(len(out_dict["input_ids"]), batch_size)
|
||||
self.assertEqual(len(out_dict["attention_mask"]), batch_size)
|
||||
self.assertEqual(len(out_dict[input_name]), batch_size * 192)
|
||||
|
||||
video_len = 360 if batch_size == 1 else 320 # qwen pixels don't scale with bs same way as other models
|
||||
mm_len = batch_size * 192 if modality == "image" else video_len
|
||||
self.assertEqual(len(out_dict[input_name]), mm_len)
|
||||
|
||||
return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
|
||||
for k in out_dict:
|
||||
@ -373,7 +396,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
|
tests/models/qwen2_vl/test_video_processing_qwen2_vl.py (new file, 291 lines)
@ -0,0 +1,291 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers.image_utils import get_image_size
|
||||
from transformers.models.qwen2_vl.video_processing_qwen2_vl import smart_resize
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import Qwen2VLVideoProcessor
|
||||
|
||||
|
||||
class Qwen2VLVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_frames=8,
|
||||
num_channels=3,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_center_crop=True,
|
||||
crop_size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
temporal_patch_size=2,
|
||||
patch_size=14,
|
||||
min_pixels=20 * 20,
|
||||
max_pixels=100 * 100,
|
||||
merge_size=2,
|
||||
):
|
||||
size = size if size is not None else {"shortest_edge": 20}
|
||||
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.patch_size = patch_size
|
||||
self.min_pixels = min_pixels
|
||||
self.max_pixels = max_pixels
|
||||
self.merge_size = merge_size
|
||||
|
||||
def prepare_video_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"do_center_crop": self.do_center_crop,
|
||||
"crop_size": self.crop_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
"temporal_patch_size": self.temporal_patch_size,
|
||||
"patch_size": self.patch_size,
|
||||
"min_pixels": self.min_pixels,
|
||||
"max_pixels": self.max_pixels,
|
||||
"merge_size": self.merge_size,
|
||||
}
|
||||
|
||||
@require_vision
|
||||
def expected_output_video_shape(self, videos):
|
||||
grid_t = self.num_frames // self.temporal_patch_size
|
||||
hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
|
||||
seq_len = 0
|
||||
for video in videos:
|
||||
if isinstance(video[0], Image.Image):
|
||||
video = np.stack([np.array(frame) for frame in video])
|
||||
height, width = get_image_size(video)
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=self.patch_size * self.merge_size,
|
||||
min_pixels=self.min_pixels,
|
||||
max_pixels=self.max_pixels,
|
||||
)
|
||||
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
||||
seq_len += grid_t * grid_h * grid_w
|
||||
return [seq_len, hidden_dim]
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
|
||||
videos = prepare_video_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_frames=self.num_frames,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Qwen2VLVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
||||
fast_video_processing_class = Qwen2VLVideoProcessor if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.video_processor_tester = Qwen2VLVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
def video_processor_dict(self):
|
||||
return self.video_processor_tester.prepare_video_processor_dict()
|
||||
|
||||
def test_video_processor_properties(self):
|
||||
video_processing = self.fast_video_processing_class(**self.video_processor_dict)
|
||||
self.assertTrue(hasattr(video_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(video_processing, "size"))
|
||||
self.assertTrue(hasattr(video_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(video_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(video_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(video_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(video_processing, "image_std"))
|
||||
self.assertTrue(hasattr(video_processing, "do_convert_rgb"))
|
||||
|
||||
# OVERRIDDEN BECAUSE QWEN2_VL HAS SPECIAL OUTPUT SHAPES
|
||||
def test_video_processor_from_dict_with_kwargs(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor = video_processing_class(**self.video_processor_dict)
|
||||
self.assertEqual(video_processor.min_pixels, self.video_processor_tester.min_pixels)
|
||||
self.assertEqual(video_processor.max_pixels, self.video_processor_tester.max_pixels)
|
||||
|
||||
video_processor = video_processing_class.from_dict(
|
||||
self.video_processor_dict, min_pixels=256 * 256, max_pixels=640 * 640
|
||||
)
|
||||
self.assertEqual(video_processor.min_pixels, 256 * 256)
|
||||
self.assertEqual(video_processor.max_pixels, 640 * 640)
|
||||
|
||||
def test_call_pil(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Initialize video_processing
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="pil"
|
||||
)
|
||||
|
||||
# Each video is a list of PIL Images
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Initialize video_processing
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="np"
|
||||
)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Initialize video_processing
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="torch"
|
||||
)
|
||||
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
self.assertEqual(
|
||||
list(encoded_videos.shape),
|
||||
expected_output_video_shape,
|
||||
)
|
||||
|
||||
def test_nested_input(self):
|
||||
"""Tests that the processor can work with nested list where each video is a list of arrays"""
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="np"
|
||||
)
|
||||
|
||||
# Test not batched input
|
||||
video_inputs_nested = [list(video) for video in video_inputs]
|
||||
encoded_videos = video_processing(video_inputs_nested[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
encoded_videos = video_processing(video_inputs_nested, return_tensors="pt")[self.input_name]
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
@unittest.skip("Skip for now, the test needs adjustment fo Qwen2VL")
|
||||
def test_call_numpy_4_channels(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Test that can process videos which have an arbitrary number of channels
|
||||
# Initialize video_processing
|
||||
video_processor = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
# create random numpy tensors
|
||||
self.video_processor_tester.num_channels = 4
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="np"
|
||||
)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processor(
|
||||
video_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
)[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processor(
|
||||
video_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
)[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
self.assertEqual(list(encoded_videos.shape), expected_output_video_shape)
|
@@ -22,7 +22,7 @@ import requests

from transformers import SmolVLMProcessor
from transformers.models.auto.processing_auto import AutoProcessor
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.testing_utils import is_flaky, require_av, require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin
@@ -63,6 +63,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
        )
        cls.bos_token = processor.tokenizer.bos_token
        cls.image_token = processor.image_token
        cls.video_token = processor.image_token * 8  # SmolVLM uses image token and repeats it `num_frames` times
        cls.fake_image_token = processor.fake_image_token
        cls.global_img_token = processor.global_image_token

@@ -79,6 +80,9 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def get_video_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

    def get_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)

@@ -114,6 +118,10 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

    @is_flaky  # fails 15 out of 100, FIXME @raushan
    def test_structured_kwargs_nested_from_dict_video(self):
        super().test_structured_kwargs_nested_from_dict_video()

    def test_process_interleaved_images_prompts_no_image_splitting(self):
        processor_components = self.prepare_components()
        processor_components["tokenizer"] = self.get_component("tokenizer", padding_side="left")
@@ -433,10 +441,13 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
        if "image_processor" not in self.processor_class.attributes:
            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
        image_processor = self.get_component("image_processor")
        video_processor = self.get_component("video_processor")
        tokenizer = self.get_component("tokenizer")

        processor_kwargs = self.prepare_processor_dict()
        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs)
        processor = self.processor_class(
            tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor, **processor_kwargs
        )
        self.skip_processor_without_typed_kwargs(processor)

        input_str = self.prepare_text_inputs(batch_size=2, modality="image")
@@ -556,3 +567,7 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
            padding=True,
            max_length=20,
        )

    @unittest.skip("SmolVLM cannot accept image URL as video frames, because it needs to know video fps and duration")
    def test_apply_chat_template_video_1(self):
        pass
tests/models/smolvlm/test_video_processing_smolvlm.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs


if is_torch_available():
    import torch

if is_vision_available():
    if is_torchvision_available():
        from transformers import SmolVLMVideoProcessor
        from transformers.models.smolvlm.video_processing_smolvlm import get_resize_output_image_size


class SmolVLMVideoProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=5,
        num_frames=8,
        num_channels=3,
        min_resolution=30,
        max_resolution=80,
        do_resize=True,
        size=None,
        do_normalize=True,
        image_mean=IMAGENET_STANDARD_MEAN,
        image_std=IMAGENET_STANDARD_STD,
        do_convert_rgb=True,
    ):
        size = size if size is not None else {"longest_edge": 20}
        self.parent = parent
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.num_channels = num_channels
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_video_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_video_shape(self, videos):
        max_height, max_width = 0, 0
        if not isinstance(videos[0], torch.Tensor):
            videos = [torch.tensor(np.array(video)).permute(0, -1, -3, -2) for video in videos]
        for video in videos:
            height, width = get_resize_output_image_size(video, self.size["longest_edge"])
            max_height = max(height, max_height)
            max_width = max(width, max_width)
        return [self.num_frames, self.num_channels, max_height, max_width]

    def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
        videos = prepare_video_inputs(
            batch_size=self.batch_size,
            num_frames=self.num_frames,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            return_tensors=return_tensors,
        )
        return videos


@require_torch
@require_vision
class SmolVLMVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
    fast_video_processing_class = SmolVLMVideoProcessor if is_torchvision_available() else None
    input_name = "pixel_values"

    def setUp(self):
        super().setUp()
        self.video_processor_tester = SmolVLMVideoProcessingTester(self)

    @property
    def video_processor_dict(self):
        return self.video_processor_tester.prepare_video_processor_dict()

    def test_video_processor_from_dict_with_kwargs(self):
        video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
        self.assertEqual(video_processor.size, {"longest_edge": 20})

        video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42)
        self.assertEqual(video_processor.size, {"height": 42, "width": 42})
@@ -1,327 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import VideoLlavaImageProcessor
|
||||
|
||||
|
||||
class VideoLlavaImageProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_center_crop=True,
|
||||
crop_size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
size = size if size is not None else {"shortest_edge": 20}
|
||||
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_center_crop": self.do_center_crop,
|
||||
"crop_size": self.crop_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.crop_size["height"], self.crop_size["width"]
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(8, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * 8)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = VideoLlavaImageProcessor if is_vision_available() else None
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->VideoLlava
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = VideoLlavaImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.test_image_processor_from_dict_with_kwargs
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 20})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values_images
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values_images
|
||||
expected_output_image_shape = (5, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(images=image_inputs[0], return_tensors="pt").pixel_values_images
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(images=image_inputs, return_tensors="pt").pixel_values_images
|
||||
expected_output_image_shape = (5, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy_videos(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(numpify=True, equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pil_videos(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# the inputs come in list of lists batched format
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values_images
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values_images
|
||||
expected_output_image_shape = (5, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch_videos(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, torchify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values_videos
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
@parameterized.expand([(True, False), (False, True)])
|
||||
def test_call_mixed(self, numpify, torchify):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(
|
||||
equal_resolution=True, numpify=numpify, torchify=torchify
|
||||
)
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True, torchify=torchify)
|
||||
|
||||
# Test not batched input
|
||||
encoded = image_processing(images=image_inputs[0], videos=video_inputs[0], return_tensors="pt")
|
||||
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||
expected_output_image_shape = (1, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded.pixel_values_videos.shape), expected_output_video_shape)
|
||||
self.assertEqual(tuple(encoded.pixel_values_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded = image_processing(images=image_inputs, videos=video_inputs, return_tensors="pt")
|
||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||
expected_output_image_shape = (5, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded.pixel_values_videos.shape), expected_output_video_shape)
|
||||
self.assertEqual(tuple(encoded.pixel_values_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Test that can process images which have an arbitrary number of channels
|
||||
# Initialize image_processing
|
||||
image_processor = self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processor(
|
||||
image_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values_images
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processor(
|
||||
image_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
).pixel_values_images
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape)
|
||||
)
|
tests/models/video_llava/test_video_processing_video_llava.py (new file, 122 lines)
@@ -0,0 +1,122 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
pass
|
||||
|
||||
if is_vision_available():
|
||||
if is_torchvision_available():
|
||||
from transformers import VideoLlavaVideoProcessor
|
||||
|
||||
|
||||
class VideoLlavaVideoProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_frames=8,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_center_crop=True,
|
||||
crop_size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
):
|
||||
super().__init__()
|
||||
size = size if size is not None else {"shortest_edge": 20}
|
||||
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_frames = num_frames
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
def prepare_video_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_center_crop": self.do_center_crop,
|
||||
"crop_size": self.crop_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_video_shape(self, images):
|
||||
return self.num_frames, self.num_channels, self.crop_size["height"], self.crop_size["width"]
|
||||
|
||||
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
|
||||
videos = prepare_video_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_frames=self.num_frames,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class VideoLlavaVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
|
||||
fast_video_processing_class = VideoLlavaVideoProcessor if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.video_processor_tester = VideoLlavaVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
def video_processor_dict(self):
|
||||
return self.video_processor_tester.prepare_video_processor_dict()
|
||||
|
||||
def test_video_processor_properties(self):
|
||||
video_processing = self.fast_video_processing_class(**self.video_processor_dict)
|
||||
self.assertTrue(hasattr(video_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(video_processing, "size"))
|
||||
self.assertTrue(hasattr(video_processing, "do_center_crop"))
|
||||
self.assertTrue(hasattr(video_processing, "center_crop"))
|
||||
self.assertTrue(hasattr(video_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(video_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(video_processing, "image_std"))
|
||||
self.assertTrue(hasattr(video_processing, "do_convert_rgb"))
|
@@ -179,7 +179,7 @@ class ImageProcessingTestMixin:

        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
        torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
        self.assertLessEqual(
            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 5e-3
        )
@@ -205,7 +205,7 @@ class ImageProcessingTestMixin:
        encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
        encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")

        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
        torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
        self.assertLessEqual(
            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 5e-3
        )
@@ -539,7 +539,7 @@ class ProcessorTesterMixin:
|
||||
video_input = self.prepare_video_inputs()
|
||||
|
||||
inputs = processor(text=input_str, videos=video_input, return_tensors="pt")
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
|
||||
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs_video(self):
|
||||
if "video_processor" not in self.processor_class.attributes:
|
||||
@@ -574,7 +574,7 @@ class ProcessorTesterMixin:
|
||||
video_input = self.prepare_video_inputs()
|
||||
|
||||
inputs = processor(text=input_str, videos=video_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
|
||||
|
||||
def test_unstructured_kwargs_video(self):
|
||||
if "video_processor" not in self.processor_class.attributes:
|
||||
@@ -596,7 +596,7 @@ class ProcessorTesterMixin:
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
|
||||
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
|
||||
|
||||
def test_unstructured_kwargs_batched_video(self):
|
||||
@@ -619,7 +619,7 @@ class ProcessorTesterMixin:
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
|
||||
self.assertTrue(
|
||||
len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1])
|
||||
and len(inputs[self.text_input_name][1]) < 76
|
||||
@@ -665,7 +665,7 @@ class ProcessorTesterMixin:
|
||||
inputs = processor(text=input_str, videos=video_input, **all_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
|
||||
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
|
||||
|
||||
def test_structured_kwargs_nested_from_dict_video(self):
|
||||
@@ -686,7 +686,7 @@ class ProcessorTesterMixin:
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, videos=video_input, **all_kwargs)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0][0][0].mean(), 0)
|
||||
self.assertLessEqual(inputs[self.videos_input_name][0].mean(), 0)
|
||||
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
|
||||
|
||||
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs
|
||||
@@ -907,15 +907,15 @@ class ProcessorTesterMixin:
|
||||
for prompt in continue_prompt:
|
||||
self.assertTrue(prompt.endswith("It is the sound of")) # no `eos` token at the end
|
||||
|
||||
@require_av
|
||||
@require_librosa
|
||||
@parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
|
||||
def test_apply_chat_template_audio(self, batch_size: int, return_tensors: str):
|
||||
self._test_apply_chat_template(
|
||||
"audio", batch_size, return_tensors, "audio_input_name", "feature_extracttor", MODALITY_INPUT_DATA["audio"]
|
||||
)
|
||||
|
||||
@require_librosa
|
||||
@parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
|
||||
@require_av
|
||||
@parameterized.expand([(1, "pt"), (2, "pt")]) # video processor suports only torchvision
|
||||
def test_apply_chat_template_video(self, batch_size: int, return_tensors: str):
|
||||
self._test_apply_chat_template(
|
||||
"video", batch_size, return_tensors, "videos_input_name", "video_processor", MODALITY_INPUT_DATA["videos"]
|
||||
@@ -927,6 +927,7 @@ class ProcessorTesterMixin:
|
||||
"image", batch_size, return_tensors, "images_input_name", "image_processor", MODALITY_INPUT_DATA["images"]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_apply_chat_template_video_frame_sampling(self):
|
||||
processor = self.get_processor()
|
||||
|
||||
@@ -962,7 +963,7 @@ class ProcessorTesterMixin:
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
num_frames=num_frames,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
@@ -976,7 +977,7 @@ class ProcessorTesterMixin:
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=video_fps,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
@@ -1024,6 +1025,7 @@ class ProcessorTesterMixin:
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)
|
||||
|
||||
@require_av
|
||||
@require_torch
|
||||
def test_apply_chat_template_video_special_processing(self):
|
||||
"""
|
||||
Tests that models can use their own preprocessing to preprocess conversations.
|
||||
@@ -1081,7 +1083,7 @@ class ProcessorTesterMixin:
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="np",
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertTrue(self.videos_input_name in out_dict_with_video)
|
||||
|
||||
|
tests/test_video_processing_common.py (new file, 395 lines)
@@ -0,0 +1,395 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
|
||||
from transformers import AutoVideoProcessor
|
||||
from transformers.testing_utils import (
|
||||
check_json_file_has_correct_format,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def prepare_video(num_frames, num_channels, width=10, height=10, return_tensors="pil"):
|
||||
"""This function prepares a video as a list of PIL images/NumPy arrays/PyTorch tensors."""
|
||||
|
||||
video = []
|
||||
for i in range(num_frames):
|
||||
video.append(np.random.randint(255, size=(width, height, num_channels), dtype=np.uint8))
|
||||
|
||||
if return_tensors == "pil":
|
||||
# PIL expects the channel dimension as last dimension
|
||||
video = [Image.fromarray(frame) for frame in video]
|
||||
elif return_tensors == "torch":
|
||||
# Torch images are typically in channels first format
|
||||
video = torch.tensor(video).permute(0, 3, 1, 2)
|
||||
elif return_tensors == "np":
|
||||
# Numpy images are typically in channels last format
|
||||
video = np.array(video)
|
||||
|
||||
return video
|
||||
|
||||
|
||||
def prepare_video_inputs(
|
||||
batch_size,
|
||||
num_frames,
|
||||
num_channels,
|
||||
min_resolution,
|
||||
max_resolution,
|
||||
equal_resolution=False,
|
||||
return_tensors="pil",
|
||||
):
|
||||
"""This function prepares a batch of videos: a list of list of PIL images, or a list of list of numpy arrays if
|
||||
one specifies return_tensors="np", or a list of list of PyTorch tensors if one specifies return_tensors="torch".
|
||||
|
||||
One can specify whether the videos are of the same resolution or not.
|
||||
"""
|
||||
|
||||
video_inputs = []
|
||||
for i in range(batch_size):
|
||||
if equal_resolution:
|
||||
width = height = max_resolution
|
||||
else:
|
||||
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
|
||||
video = prepare_video(
|
||||
num_frames=num_frames,
|
||||
num_channels=num_channels,
|
||||
width=width,
|
||||
height=height,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
video_inputs.append(video)
|
||||
|
||||
return video_inputs
|
||||
|
||||
|
||||
class VideoProcessingTestMixin:
|
||||
test_cast_dtype = None
|
||||
fast_video_processing_class = None
|
||||
video_processor_list = None
|
||||
input_name = "pixel_values_videos"
|
||||
|
||||
def setUp(self):
|
||||
video_processor_list = []
|
||||
|
||||
if self.fast_video_processing_class:
|
||||
video_processor_list.append(self.fast_video_processing_class)
|
||||
|
||||
self.video_processor_list = video_processor_list
|
||||
|
||||
def test_video_processor_to_json_string(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor = video_processing_class(**self.video_processor_dict)
|
||||
obj = json.loads(video_processor.to_json_string())
|
||||
for key, value in self.video_processor_dict.items():
|
||||
self.assertEqual(obj[key], value)
|
||||
|
||||
def test_video_processor_to_json_file(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor_first = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
json_file_path = os.path.join(tmpdirname, "video_processor.json")
|
||||
video_processor_first.to_json_file(json_file_path)
|
||||
video_processor_second = video_processing_class.from_json_file(json_file_path)
|
||||
|
||||
self.assertEqual(video_processor_second.to_dict(), video_processor_first.to_dict())
|
||||
|
||||
def test_video_processor_from_dict_with_kwargs(self):
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
|
||||
self.assertEqual(video_processor.size, {"shortest_edge": 20})
|
||||
self.assertEqual(video_processor.crop_size, {"height": 18, "width": 18})
|
||||
|
||||
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42, crop_size=84)
|
||||
self.assertEqual(video_processor.size, {"shortest_edge": 42})
|
||||
self.assertEqual(video_processor.crop_size, {"height": 84, "width": 84})
|
||||
|
||||
def test_video_processor_from_and_save_pretrained(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor_first = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = video_processor_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
video_processor_second = video_processing_class.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertEqual(video_processor_second.to_dict(), video_processor_first.to_dict())
|
||||
|
||||
def test_video_processor_save_load_with_autovideoprocessor(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor_first = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = video_processor_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
|
||||
use_fast = video_processing_class.__name__.endswith("Fast")
|
||||
video_processor_second = AutoVideoProcessor.from_pretrained(tmpdirname, use_fast=use_fast)
|
||||
|
||||
self.assertEqual(video_processor_second.to_dict(), video_processor_first.to_dict())
|
||||
|
||||
def test_init_without_params(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor = video_processing_class()
|
||||
self.assertIsNotNone(video_processor)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_vision
|
||||
def test_can_compile_fast_video_processor(self):
|
||||
if self.fast_video_processing_class is None:
|
||||
self.skipTest("Skipping compilation test as fast video processor is not defined")
|
||||
if version.parse(torch.__version__) < version.parse("2.3"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
torch.compiler.reset()
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(equal_resolution=False, return_tensors="torch")
|
||||
video_processor = self.fast_video_processing_class(**self.video_processor_dict)
|
||||
output_eager = video_processor(video_inputs, device=torch_device, return_tensors="pt")
|
||||
|
||||
video_processor = torch.compile(video_processor, mode="reduce-overhead")
|
||||
output_compiled = video_processor(video_inputs, device=torch_device, return_tensors="pt")
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_eager[self.input_name], output_compiled[self.input_name], rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_cast_dtype_device(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
if self.test_cast_dtype is not None:
|
||||
# Initialize video_processor
|
||||
video_processor = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="torch"
|
||||
)
|
||||
|
||||
encoding = video_processor(video_inputs, return_tensors="pt")
|
||||
|
||||
self.assertEqual(encoding[self.input_name].device, torch.device("cpu"))
|
||||
self.assertEqual(encoding[self.input_name].dtype, torch.float32)
|
||||
|
||||
encoding = video_processor(video_inputs, return_tensors="pt").to(torch.float16)
|
||||
self.assertEqual(encoding[self.input_name].device, torch.device("cpu"))
|
||||
self.assertEqual(encoding[self.input_name].dtype, torch.float16)
|
||||
|
||||
encoding = video_processor(video_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
|
||||
self.assertEqual(encoding[self.input_name].device, torch.device("cpu"))
|
||||
self.assertEqual(encoding[self.input_name].dtype, torch.bfloat16)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_ = video_processor(video_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
|
||||
|
||||
# Try with text + video feature
|
||||
encoding = video_processor(video_inputs, return_tensors="pt")
|
||||
encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
|
||||
encoding = encoding.to(torch.float16)
|
||||
|
||||
self.assertEqual(encoding[self.input_name].device, torch.device("cpu"))
|
||||
self.assertEqual(encoding[self.input_name].dtype, torch.float16)
|
||||
self.assertEqual(encoding.input_ids.dtype, torch.long)
|
||||
|
||||
def test_call_pil(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Initialize video_processing
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(equal_resolution=False)
|
||||
|
||||
# Each video is a list of PIL Images
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape), (self.video_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Initialize video_processing
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="np"
|
||||
)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape), (self.video_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Initialize video_processing
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="torch"
|
||||
)
|
||||
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape),
|
||||
(self.video_processor_tester.batch_size, *expected_output_video_shape),
|
||||
)
|
||||
|
||||
def test_nested_input(self):
|
||||
"""Tests that the processor can work with nested list where each video is a list of arrays"""
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processing = video_processing_class(**self.video_processor_dict)
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="np"
|
||||
)
|
||||
|
||||
# Test not batched input
|
||||
video_inputs = [list(video) for video in video_inputs]
|
||||
encoded_videos = video_processing(video_inputs[0], return_tensors="pt")[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
encoded_videos = video_processing(video_inputs, return_tensors="pt")[self.input_name]
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape),
|
||||
(self.video_processor_tester.batch_size, *expected_output_video_shape),
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
for video_processing_class in self.video_processor_list:
|
||||
# Test that can process videos which have an arbitrary number of channels
|
||||
# Initialize video_processing
|
||||
video_processor = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
# create random numpy tensors
|
||||
self.video_processor_tester.num_channels = 4
|
||||
video_inputs = self.video_processor_tester.prepare_video_inputs(
|
||||
equal_resolution=False, return_tensors="pil"
|
||||
)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = video_processor(
|
||||
video_inputs[0],
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
)[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape([video_inputs[0]])
|
||||
if video_processor.do_convert_rgb:
|
||||
expected_output_video_shape = list(expected_output_video_shape)
|
||||
expected_output_video_shape[1] = 3
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_videos = video_processor(
|
||||
video_inputs,
|
||||
return_tensors="pt",
|
||||
input_data_format="channels_last",
|
||||
image_mean=0,
|
||||
image_std=1,
|
||||
)[self.input_name]
|
||||
expected_output_video_shape = self.video_processor_tester.expected_output_video_shape(video_inputs)
|
||||
if video_processor.do_convert_rgb:
|
||||
expected_output_video_shape = list(expected_output_video_shape)
|
||||
expected_output_video_shape[1] = 3
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape), (self.video_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
|
||||
def test_video_processor_preprocess_arguments(self):
|
||||
is_tested = False
|
||||
|
||||
for video_processing_class in self.video_processor_list:
|
||||
video_processor = video_processing_class(**self.video_processor_dict)
|
||||
|
||||
# validation done by _valid_processor_keys attribute
|
||||
if hasattr(video_processor, "_valid_processor_keys") and hasattr(video_processor, "preprocess"):
|
||||
preprocess_parameter_names = inspect.getfullargspec(video_processor.preprocess).args
|
||||
preprocess_parameter_names.remove("self")
|
||||
preprocess_parameter_names.sort()
|
||||
valid_processor_keys = video_processor._valid_processor_keys
|
||||
valid_processor_keys.sort()
|
||||
self.assertEqual(preprocess_parameter_names, valid_processor_keys)
|
||||
is_tested = True
|
||||
|
||||
# validation done by @filter_out_non_signature_kwargs decorator
|
||||
if hasattr(video_processor.preprocess, "_filter_out_non_signature_kwargs"):
|
||||
if hasattr(self.video_processor_tester, "prepare_video_inputs"):
inputs = self.video_processor_tester.prepare_video_inputs()
else:
self.skipTest(reason="No valid input preparation method found")
|
||||
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter("always")
|
||||
video_processor(inputs, extra_argument=True)
|
||||
|
||||
messages = " ".join([str(w.message) for w in raised_warnings])
|
||||
self.assertGreaterEqual(len(raised_warnings), 1)
|
||||
self.assertIn("extra_argument", messages)
|
||||
is_tested = True
|
||||
|
||||
if not is_tested:
|
||||
self.skipTest(reason="No validation found for `preprocess` method")
|
@@ -30,7 +30,6 @@ from transformers import is_torch_available, is_vision_available
|
||||
from transformers.image_utils import (
|
||||
ChannelDimension,
|
||||
get_channel_dimension_axis,
|
||||
make_batched_videos,
|
||||
make_flat_list_of_images,
|
||||
make_list_of_images,
|
||||
make_nested_list_of_images,
|
||||
@@ -396,133 +395,6 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
|
||||
|
||||
def test_make_batched_videos_pil(self):
|
||||
# Test a single image is converted to a list of 1 video with 1 frame
|
||||
pil_image = get_random_image(16, 32)
|
||||
videos_list = make_batched_videos(pil_image)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertEqual(len(videos_list[0]), 1)
|
||||
self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
|
||||
|
||||
# Test a list of images is converted to a list of 1 video
|
||||
images = [get_random_image(16, 32) for _ in range(4)]
|
||||
videos_list = make_batched_videos(images)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertEqual(len(videos_list), 1)
|
||||
self.assertEqual(len(videos_list[0]), 4)
|
||||
self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
|
||||
|
||||
# Test a nested list of images is not modified
|
||||
images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
|
||||
videos_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertEqual(len(videos_list), 2)
|
||||
self.assertEqual(len(videos_list[0]), 2)
|
||||
self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
|
||||
|
||||
def test_make_batched_videos_numpy(self):
|
||||
# Test a single image is converted to a list of 1 video with 1 frame
|
||||
images = np.random.randint(0, 256, (16, 32, 3))
|
||||
videos_list = make_batched_videos(images)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertEqual(len(videos_list), 1)
|
||||
self.assertTrue(np.array_equal(videos_list[0][0], images))
|
||||
|
||||
# Test a 4d array of images is converted to a list of 1 video
|
||||
images = np.random.randint(0, 256, (4, 16, 32, 3))
|
||||
videos_list = make_batched_videos(images)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertIsInstance(videos_list[0][0], np.ndarray)
|
||||
self.assertEqual(len(videos_list), 1)
|
||||
self.assertEqual(len(videos_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
|
||||
|
||||
# Test a list of images is converted to a list of videos
|
||||
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
videos_list = make_batched_videos(images)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertEqual(len(videos_list), 1)
|
||||
self.assertEqual(len(videos_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
|
||||
|
||||
# Test a nested list of images is left unchanged
|
||||
images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
|
||||
videos_list = make_batched_videos(images)
|
||||
self.assertIsInstance(videos_list[0], list)
|
||||
self.assertEqual(len(videos_list), 2)
|
||||
        self.assertEqual(len(videos_list[0]), 2)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))

        # Test a list of 4d array images is converted to a list of videos
        images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertIsInstance(videos_list[0][0], np.ndarray)
        self.assertEqual(len(videos_list), 2)
        self.assertEqual(len(videos_list[0]), 4)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))

        # Test a batch of list of 4d array images is converted to a list of videos
        images = [[np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] for _ in range(2)]
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertIsInstance(videos_list[0][0], np.ndarray)
        self.assertEqual(len(videos_list), 2)
        self.assertEqual(len(videos_list[0]), 8)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0][0]))

    @require_torch
    def test_make_batched_videos_torch(self):
        # Test a single image is converted to a list of 1 video with 1 frame
        images = torch.randint(0, 256, (16, 32, 3))
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertEqual(len(videos_list[0]), 1)
        self.assertTrue(np.array_equal(videos_list[0][0], images))

        # Test a 4d tensor of images is converted to a list of 1 video
        images = torch.randint(0, 256, (4, 16, 32, 3))
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertIsInstance(videos_list[0][0], torch.Tensor)
        self.assertEqual(len(videos_list), 1)
        self.assertEqual(len(videos_list[0]), 4)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0]))

        # Test a list of images is converted to a list of 1 video
        images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertEqual(len(videos_list), 1)
        self.assertEqual(len(videos_list[0]), 4)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0]))

        # Test a nested list of images is left unchanged
        images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertEqual(len(videos_list), 2)
        self.assertEqual(len(videos_list[0]), 2)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))

        # Test a list of 4d tensor images is converted to a list of videos
        images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertIsInstance(videos_list[0][0], torch.Tensor)
        self.assertEqual(len(videos_list), 2)
        self.assertEqual(len(videos_list[0]), 4)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))

        # Test a batch of list of 4d tensor images is converted to a list of videos
        images = [[torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] for _ in range(2)]
        videos_list = make_batched_videos(images)
        self.assertIsInstance(videos_list[0], list)
        self.assertIsInstance(videos_list[0][0], torch.Tensor)
        self.assertEqual(len(videos_list), 2)
        self.assertEqual(len(videos_list[0]), 8)
        self.assertTrue(np.array_equal(videos_list[0][0], images[0][0][0]))

    @require_torch
    def test_conversion_torch_to_array(self):
        feature_extractor = ImageFeatureExtractionMixin()
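The assertions above pin down the contract of the `make_batched_videos` helper exercised in this file: however the input is laid out, the output is a list of videos, each video being a list of frames. A minimal sketch of those round-trips, grounded only in the assertions above (the import path is an assumption, since this file's imports fall outside the excerpt):

# Editor's sketch, not part of the diff. The import below is an assumption; this test
# file's own imports are not shown in the excerpt above.
import numpy as np

from transformers.image_utils import make_batched_videos  # assumed import path

frames = [np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8) for _ in range(4)]

videos = make_batched_videos(frames)            # flat list of frames -> 1 video with 4 frames
assert len(videos) == 1 and len(videos[0]) == 4

videos = make_batched_videos([frames, frames])  # nested list -> 2 videos with 4 frames each
assert len(videos) == 2 and len(videos[0]) == 4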
286  tests/utils/test_video_utils.py  Normal file
@ -0,0 +1,286 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from huggingface_hub import hf_hub_download

from transformers import is_torch_available, is_vision_available
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import SizeDict
from transformers.processing_utils import VideosKwargs
from transformers.testing_utils import (
    require_av,
    require_cv2,
    require_decord,
    require_torch,
    require_torchvision,
    require_vision,
)
from transformers.video_utils import make_batched_videos


if is_torch_available():
    import torch

if is_vision_available():
    import PIL

    from transformers import BaseVideoProcessor
    from transformers.video_utils import VideoMetadata, load_video


def get_random_video(height, width, return_torch=False):
    random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    video = np.array([random_frame] * 8)
    if return_torch:
        # move channel first
        return torch.from_numpy(video).permute(0, 3, 1, 2)
    return video


@require_vision
@require_torchvision
class BaseVideoProcessorTester(unittest.TestCase):
    """
    Tests that the `transforms` can be applied to a 4-dim array directly, i.e. to a whole video.
    """
    def test_make_batched_videos_pil(self):
        # Test a single image is converted to a list of 1 video with 1 frame
        video = get_random_video(16, 32)
        pil_image = PIL.Image.fromarray(video[0])
        videos_list = make_batched_videos(pil_image)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0][0], np.array(pil_image)))

        # Test a list of PIL frames is converted to a list of 1 video
        video = get_random_video(16, 32)
        video = [PIL.Image.fromarray(frame) for frame in video]
        videos_list = make_batched_videos(video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

        # Test a nested list of videos is not modified
        video = get_random_video(16, 32)
        video = [PIL.Image.fromarray(frame) for frame in video]
        videos = [video, video]
        videos_list = make_batched_videos(videos)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

    def test_make_batched_videos_numpy(self):
        # Test a single image is converted to a list of 1 video with 1 frame
        video = get_random_video(16, 32)[0]
        videos_list = make_batched_videos(video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0][0], video))

        # Test a 4d array of videos is converted to a list of 1 video
        video = get_random_video(16, 32)
        videos_list = make_batched_videos(video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

        # Test a list of videos is converted to a list of videos
        video = get_random_video(16, 32)
        videos = [video, video]
        videos_list = make_batched_videos(videos)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

    @require_torch
    def test_make_batched_videos_torch(self):
        # Test a single image is converted to a list of 1 video with 1 frame
        video = get_random_video(16, 32)[0]
        torch_video = torch.from_numpy(video)
        videos_list = make_batched_videos(torch_video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0][0], video))

        # Test a 4d array of videos is converted to a list of 1 video
        video = get_random_video(16, 32)
        torch_video = torch.from_numpy(video)
        videos_list = make_batched_videos(torch_video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], torch.Tensor)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

        # Test a list of videos is converted to a list of videos
        video = get_random_video(16, 32)
        torch_video = torch.from_numpy(video)
        videos = [torch_video, torch_video]
        videos_list = make_batched_videos(videos)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], torch.Tensor)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))
    def test_resize(self):
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        video = get_random_video(16, 32, return_torch=True)

        # Size can be an int or a tuple of ints.
        size_dict = SizeDict(**get_size_dict((8, 8), param_name="size"))
        resized_video = video_processor.resize(video, size=size_dict)
        self.assertIsInstance(resized_video, torch.Tensor)
        self.assertEqual(resized_video.shape, (8, 3, 8, 8))

    def test_normalize(self):
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        array = torch.randn(4, 3, 16, 32)
        mean = [0.1, 0.5, 0.9]
        std = [0.2, 0.4, 0.6]

        # mean and std can be passed as lists or NumPy arrays.
        expected = (array - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
        normalized_array = video_processor.normalize(array, mean, std)
        torch.testing.assert_close(normalized_array, expected)

    def test_center_crop(self):
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        video = get_random_video(16, 32, return_torch=True)

        # Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
        crop_sizes = [8, (8, 64), 20, (32, 64)]
        for size in crop_sizes:
            size_dict = SizeDict(**get_size_dict(size, default_to_square=True, param_name="crop_size"))
            cropped_video = video_processor.center_crop(video, size_dict)
            self.assertIsInstance(cropped_video, torch.Tensor)

            expected_size = (size, size) if isinstance(size, int) else size
            self.assertEqual(cropped_video.shape, (8, 3, *expected_size))

    def test_convert_to_rgb(self):
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        video = get_random_video(20, 20, return_torch=True)

        rgb_video = video_processor.convert_to_rgb(video[:, :1])
        self.assertEqual(rgb_video.shape, (8, 3, 20, 20))

        rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
        self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
@require_vision
@require_av
class LoadVideoTester(unittest.TestCase):
    def test_load_video_url(self):
        video, _ = load_video(
            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
        )
        self.assertEqual(video.shape, (243, 360, 640, 3))  # 243 frames is the whole video, no sampling applied

    def test_load_video_local(self):
        video_file_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )
        video, _ = load_video(video_file_path)
        self.assertEqual(video.shape, (243, 360, 640, 3))  # 243 frames is the whole video, no sampling applied

    # FIXME: @raushan, yt-dlp downloading works but for some reason it cannot redirect to our buffer?
    # @requires_yt_dlp
    # def test_load_video_youtube(self):
    #     video = load_video("https://www.youtube.com/watch?v=QC8iQqtG0hg")
    #     self.assertEqual(video.shape, (243, 360, 640, 3))  # 243 frames is the whole video, no sampling applied

    @require_decord
    @require_torchvision
    @require_cv2
    def test_load_video_backend_url(self):
        video, _ = load_video(
            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
            backend="decord",
        )
        self.assertEqual(video.shape, (243, 360, 640, 3))

        # Can't use certain backends with a URL
        with self.assertRaises(ValueError):
            video, _ = load_video(
                "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
                backend="opencv",
            )
        with self.assertRaises(ValueError):
            video, _ = load_video(
                "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
                backend="torchvision",
            )

    @require_decord
    @require_torchvision
    @require_cv2
    def test_load_video_backend_local(self):
        video_file_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )
        video, metadata = load_video(video_file_path, backend="decord")
        self.assertEqual(video.shape, (243, 360, 640, 3))
        self.assertIsInstance(metadata, VideoMetadata)

        video, metadata = load_video(video_file_path, backend="opencv")
        self.assertEqual(video.shape, (243, 360, 640, 3))
        self.assertIsInstance(metadata, VideoMetadata)

        video, metadata = load_video(video_file_path, backend="torchvision")
        self.assertEqual(video.shape, (243, 360, 640, 3))
        self.assertIsInstance(metadata, VideoMetadata)

    def test_load_video_num_frames(self):
        video, _ = load_video(
            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
            num_frames=16,
        )
        self.assertEqual(video.shape, (16, 360, 640, 3))

        video, _ = load_video(
            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
            num_frames=22,
        )
        self.assertEqual(video.shape, (22, 360, 640, 3))

    def test_load_video_fps(self):
        video, _ = load_video(
            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", fps=1
        )
        self.assertEqual(video.shape, (9, 360, 640, 3))

        video, _ = load_video(
            "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", fps=2
        )
        self.assertEqual(video.shape, (19, 360, 640, 3))

        # `num_frames` is mutually exclusive with `fps`
        with self.assertRaises(ValueError):
            video, _ = load_video(
                "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
                fps=1,
                num_frames=10,
            )
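Taken together, the `LoadVideoTester` cases document the `load_video` contract: it takes a URL or local path, returns the decoded frames plus a `VideoMetadata` object, lets you subsample with either `num_frames` or `fps` (but not both), and only some backends can read directly from a URL. A minimal usage sketch, based only on the calls exercised above:

# Sketch based on the tests above; the sample URL and frame counts come from those tests.
from transformers.video_utils import load_video

url = "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4"

video, metadata = load_video(url)             # full clip, shape (243, 360, 640, 3)
video, _ = load_video(url, num_frames=16)     # keep exactly 16 frames
video, _ = load_video(url, fps=1)             # keep roughly one frame per second of video
video, _ = load_video(url, backend="decord")  # decord also accepts URLs; opencv/torchvision do not

# Passing both `num_frames` and `fps` raises a ValueError, as the last test asserts.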
@ -618,6 +618,7 @@ ALL_FILE_TYPES = (
    "tokenization",
    "processing",
    "image_processing",
    "video_processing",
    "feature_extractor",
)
@ -1133,9 +1134,12 @@ TYPE_TO_FILE_TYPE = {
    "Processor": "processing",
    "ImageProcessor": "image_processing",
    "ImageProcessorFast": "image_processing*_fast",  # "*" indicates where to insert the model name before the "_fast" suffix
    "VideoProcessor": "video_processing",
    "VideoProcessorInitKwargs": "video_processing",
    "FastImageProcessorKwargs": "image_processing*_fast",
    "FeatureExtractor": "feature_extractor",
    "ProcessorKwargs": "processing",
    "VideosKwargs": "processing",
    "ImagesKwargs": "processing",
    "TextKwargs": "processing",
}
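For orientation, the mapping above resolves a class-name suffix to the module file that should hold it, with `*` marking where the model name is spliced in before the `_fast` suffix. A hypothetical helper (`filename_for` is not part of this PR) illustrating the intent:

# Hypothetical illustration of how a TYPE_TO_FILE_TYPE pattern could resolve to a filename.
def filename_for(file_type_pattern: str, model_name: str) -> str:
    """E.g. ("image_processing*_fast", "llava_onevision") -> image_processing_llava_onevision_fast.py"""
    if "*" in file_type_pattern:
        stem = file_type_pattern.replace("*", f"_{model_name}")
    else:
        stem = f"{file_type_pattern}_{model_name}"
    return f"{stem}.py"

print(filename_for("video_processing", "llava_onevision"))        # video_processing_llava_onevision.py
print(filename_for("image_processing*_fast", "llava_onevision"))  # image_processing_llava_onevision_fast.py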
5  utils/test_module/custom_video_processing.py  Normal file
@ -0,0 +1,5 @@
from transformers import LlavaOnevisionVideoProcessor


class CustomVideoProcessor(LlavaOnevisionVideoProcessor):
    pass
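This tiny module sits alongside the other `custom_*` files in `utils/test_module` that back the remote-code tests. A hedged sketch of how such a processor would typically be loaded, assuming the `AutoVideoProcessor` auto class introduced with these video processors and a Hub repo whose `auto_map` points at `custom_video_processing.CustomVideoProcessor` (the repo id below is a placeholder):

# Sketch only: assumes an AutoVideoProcessor auto class and a Hub repo whose auto_map
# references custom_video_processing.CustomVideoProcessor; the repo id is a placeholder.
from transformers import AutoVideoProcessor

video_processor = AutoVideoProcessor.from_pretrained(
    "hf-internal-testing/some-custom-video-processor",  # placeholder repo id
    trust_remote_code=True,  # required so the custom class from the repo is executed
)
print(type(video_processor).__name__)  # expected: CustomVideoProcessor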