Add Idefics2/3 and SmolVLM Fast image processors + improvements for fast image processors (#38157)

* add working idefics2 fast and improvements for fast nested images processing

* add fast image processors idefics 3 and smolvlm

* cleanup tests

* fix doc idefics2

* PR review and fix issues after merge

* Force providing disable_grouping to group_images_by_shape

* simplify group_images_by_shape

* fix modular

* Fix nits after review
Yoni Gozlan 2025-06-23 10:17:25 -04:00 committed by GitHub
parent 1a96127e46
commit d29482cc91
61 changed files with 2023 additions and 425 deletions

View File

@ -162,7 +162,7 @@ To load and run a model using Flash Attention-2, simply change the code snippet
```diff
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
).to(device)
```
@ -184,7 +184,7 @@ Quantizing a model is as simple as passing a `quantization_config` to the model.
+ )
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
+ torch_dtype=torch.float16,
+ quantization_config=quantization_config,
).to(device)
```
@ -218,7 +218,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] Idefics2ImageProcessor
- preprocess
## Idefics2ImageProcessorFast
[[autodoc]] Idefics2ImageProcessorFast
- preprocess
## Idefics2Processor
[[autodoc]] Idefics2Processor
- __call__

View File

@ -80,6 +80,9 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
[[autodoc]] Idefics3ImageProcessor
- preprocess
## Idefics3ImageProcessorFast
[[autodoc]] Idefics3ImageProcessorFast
- preprocess
## Idefics3Processor
[[autodoc]] Idefics3Processor

View File

@ -32,7 +32,7 @@ SmolVLM2 is an adaptation of the Idefics3 model with two main differences:
Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: `do_resize` and `size`.
Videos should not be upsampled.
If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*512 pixels by default.
The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 512}` is the default, but you can change it to a different value if needed.
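For illustration, a minimal sketch of customizing this behavior when loading the processor (the checkpoint name is an assumption; kwargs of this kind are forwarded to the image processor):

```python
from transformers import AutoProcessor

# Hypothetical example: halve the default longest-edge target.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",  # checkpoint name assumed for illustration
    size={"longest_edge": 2 * 512},          # default is {"longest_edge": 4 * 512}
    do_resize=True,
)
```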
@ -192,11 +192,14 @@ print(generated_texts[0])
[[autodoc]] SmolVLMForConditionalGeneration
- forward
## SmolVLMImageProcessor
[[autodoc]] SmolVLMImageProcessor
- preprocess
## SmolVLMImageProcessorFast
[[autodoc]] SmolVLMImageProcessorFast
- preprocess
## SmolVLMVideoProcessor
[[autodoc]] SmolVLMVideoProcessor
- preprocess

View File

@ -396,7 +396,7 @@ def add_fast_image_processor_file(
content_header = get_fast_image_processing_content_header(content_base_file)
content_base_file = (
f"@auto_docstring(\n"
f"@auto_docstring\n"
f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
" # This generated class can be used as a starting point for the fast image processor.\n"
" # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"

View File

@ -184,6 +184,7 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional["torch.device"]
disable_grouping: Optional[bool]
@auto_docstring
@ -480,18 +481,35 @@ class BaseImageProcessorFast(BaseImageProcessor):
) -> list["torch.Tensor"]:
"""
Prepare the input images for processing.
Args:
images (`ImageInput`):
The input images to process.
do_convert_rgb (`bool`, *optional*):
Whether to convert the images to RGB.
input_data_format (`str` or `ChannelDimension`, *optional*):
The input data format of the images.
device (`torch.device`, *optional*):
The device to put the processed images on.
Returns:
List[`torch.Tensor`]: The processed images.
"""
# Get structured images (potentially nested)
images = self._prepare_images_structure(images)
process_image_fn = partial(
self._process_image,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
process_image_partial = partial(
self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# todo: yoni - check if we can parallelize this efficiently
processed_images = []
for image in images:
processed_images.append(process_image_fn(image))
# Check if we have nested structure, assuming the nesting is consistent
has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))
if has_nested_structure:
processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
else:
processed_images = [process_image_partial(img) for img in images]
return processed_images
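To make the flat-vs-nested dispatch above concrete, here is a minimal standalone sketch; `process` is a hypothetical stand-in for `self._process_image`:

```python
from functools import partial

def process(image, scale=1.0):
    # stand-in for _process_image (RGB conversion, tensor conversion, device placement, ...)
    return image * scale

process_image_partial = partial(process, scale=2.0)

def prepare(images):
    # Nesting is assumed consistent: either every entry is a list/tuple, or none is.
    if len(images) > 0 and isinstance(images[0], (list, tuple)):
        return [[process_image_partial(img) for img in nested] for nested in images]
    return [process_image_partial(img) for img in images]

print(prepare([1, 2, 3]))      # flat   -> [2.0, 4.0, 6.0]
print(prepare([[1, 2], [3]]))  # nested -> [[2.0, 4.0], [6.0]]
```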
@ -621,11 +639,12 @@ class BaseImageProcessorFast(BaseImageProcessor):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -635,7 +654,7 @@ class BaseImageProcessorFast(BaseImageProcessor):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
@ -656,47 +675,3 @@ class BaseImageProcessorFast(BaseImageProcessor):
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("_valid_kwargs_names", None)
return encoder_dict
class SemanticSegmentationMixin:
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
"""
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`MobileNetV2ForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
# if is_torch_tensor(target_sizes):
# target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections.abc import Collection, Iterable
from math import ceil
from typing import Optional, Union
@ -841,37 +842,128 @@ def _cast_tensor_to_float(x):
return x.float()
def _group_images_by_shape(nested_images, is_nested: bool = False):
"""Helper function to flatten a single level of nested image structures and group by shape."""
grouped_images = defaultdict(list)
grouped_images_index = {}
nested_images = [nested_images] if not is_nested else nested_images
for i, sublist in enumerate(nested_images):
for j, image in enumerate(sublist):
key = (i, j) if is_nested else j
shape = image.shape[1:]
grouped_images[shape].append(image)
grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
return grouped_images, grouped_images_index
def _reconstruct_nested_structure(indices, processed_images):
"""Helper function to reconstruct a single level nested structure."""
# Find the maximum outer index
max_outer_idx = max(idx[0] for idx in indices.keys())
# Create the outer list
result = [None] * (max_outer_idx + 1)
# Group indices by outer index
nested_indices = defaultdict(list)
for i, j in indices.keys():
nested_indices[i].append(j)
for i in range(max_outer_idx + 1):
if i in nested_indices:
inner_max_idx = max(nested_indices[i])
inner_list = [None] * (inner_max_idx + 1)
for j in range(inner_max_idx + 1):
if (i, j) in indices:
shape, idx = indices[(i, j)]
inner_list[j] = processed_images[shape][idx]
result[i] = inner_list
return result
def group_images_by_shape(
images: list["torch.Tensor"],
) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
images: Union[list["torch.Tensor"], "torch.Tensor"],
disable_grouping: Optional[bool],
is_nested: bool = False,
) -> tuple[
dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
]:
"""
Groups images by shape.
Returns a dictionary with the shape as key and a list of images with that shape as value,
and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
The function supports both flat lists of tensors and nested structures.
The input must be either all flat or all nested, not a mix of both.
Args:
images (Union[list["torch.Tensor"], "torch.Tensor"]):
A list of images or a single tensor
disable_grouping (`bool`, *optional*):
Whether to disable grouping. If `None`, it is set to `True` when the images are on CPU, and `False` otherwise.
This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
is_nested (bool, *optional*, defaults to False):
Whether the images are nested.
Returns:
tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
- A dictionary with shape as key and list of images with that shape as value
- A dictionary mapping original indices to (shape, index) tuples
"""
grouped_images = {}
grouped_images_index = {}
for i, image in enumerate(images):
shape = image.shape[1:]
if shape not in grouped_images:
grouped_images[shape] = []
grouped_images[shape].append(image)
grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
# stack images with the same shape
grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()}
# If disable_grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
if disable_grouping is None:
device = images[0][0].device if is_nested else images[0].device
disable_grouping = device == "cpu"
if disable_grouping:
if is_nested:
return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
(i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
}
else:
return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}
# Handle single level nested structure
grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
# Stack images with the same shape
grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
return grouped_images, grouped_images_index
def reorder_images(
processed_images: dict[tuple[int, int], "torch.Tensor"], grouped_images_index: dict[int, tuple[int, int]]
) -> list["torch.Tensor"]:
processed_images: dict[tuple[int, int], "torch.Tensor"],
grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
is_nested: bool = False,
) -> Union[list["torch.Tensor"], "torch.Tensor"]:
"""
Reconstructs a list of images in the original order.
Reconstructs images in the original order, preserving the original structure (nested or not).
The input structure is either all flat or all nested.
Args:
processed_images (dict[tuple[int, int], "torch.Tensor"]):
Dictionary mapping shapes to batched processed images.
grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
Dictionary mapping original indices to (shape, index) tuples.
is_nested (bool, *optional*, defaults to False):
Whether the images are nested. This cannot be inferred from the input alone, since some processing functions
(e.g. functions that split images into patches) output nested images even for non-nested inputs.
Returns:
Union[list["torch.Tensor"], "torch.Tensor"]:
Images in the original structure.
"""
return [
processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
for i in range(len(grouped_images_index))
]
if not is_nested:
return [
processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
for i in range(len(grouped_images_index))
]
return _reconstruct_nested_structure(grouped_images_index, processed_images)
class NumpyToTensor:
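As a usage sketch of this pair (the import path mirrors the other diffs in this commit; shapes are arbitrary):

```python
import torch
from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]

# disable_grouping=None defers to the CPU/GPU heuristic described above.
grouped, index = group_images_by_shape(images, disable_grouping=None)
# Run any batched op per same-shape stack; a no-op clone stands in here.
processed = {shape: batch.clone() for shape, batch in grouped.items()}
restored = reorder_images(processed, index)  # flat list, original order
assert all(r.shape == img.shape for r, img in zip(restored, images))
```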

View File

@ -95,8 +95,8 @@ else:
("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
@ -148,6 +148,7 @@ else:
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
("superglue", ("SuperGlueImageProcessor",)),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),

View File

@ -27,13 +27,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
from ...utils import (
FEATURE_EXTRACTOR_NAME,
PROCESSOR_NAME,
VIDEO_PROCESSOR_NAME,
cached_file,
logging,
)
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
from ...video_processing_utils import BaseVideoProcessor
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
@ -118,6 +112,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("shieldgemma2", "ShieldGemma2Processor"),
("siglip", "SiglipProcessor"),
("siglip2", "Siglip2Processor"),
("smolvlm", "SmolVLMProcessor"),
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("speecht5", "SpeechT5Processor"),

View File

@ -105,6 +105,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
@ -112,7 +113,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
images = self.reduce_label(images)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -122,7 +123,7 @@ class BeitImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -223,6 +223,7 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
images: list["torch.Tensor"],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
disable_grouping: Optional[bool] = False,
) -> tuple:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
@ -235,6 +236,8 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
disable_grouping (`bool`, *optional*, defaults to `False`):
Whether to disable grouping of images by size.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
@ -245,7 +248,7 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
"""
pad_size = get_max_height_width(images)
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
processed_masks_grouped = {}
for shape, stacked_images in grouped_images.items():
@ -283,11 +286,12 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -299,7 +303,7 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
@ -314,7 +318,9 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
data = {}
if do_pad:
processed_images, processed_masks = self.pad(processed_images, return_pixel_mask=True)
processed_images, processed_masks = self.pad(
processed_images, return_pixel_mask=True, disable_grouping=disable_grouping
)
processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
data["pixel_mask"] = processed_masks

View File

@ -19,11 +19,7 @@ from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from ...image_transforms import get_resize_output_image_size, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
ImageInput,

View File

@ -153,10 +153,11 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -168,7 +169,7 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -17,17 +17,8 @@
from typing import TYPE_CHECKING, Optional, Union
from ...image_processing_base import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
PILImageResampling,
SizeDict,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
from ...utils import (
TensorType,
auto_docstring,
@ -85,10 +76,11 @@ class DepthProImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched scaling
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize

View File

@ -16,19 +16,9 @@
from typing import Optional, Union
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
SizeDict,
)
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
@ -230,11 +220,12 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_align_long_axis:
@ -254,7 +245,7 @@ class DonutImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -176,18 +176,19 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
return_tensors: Optional[Union[str, TensorType]],
keep_aspect_ratio: bool,
ensure_multiple_of: Optional[int],
do_pad: bool,
size_divisor: Optional[int],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
if do_reduce_labels:
images = self.reduce_label(images)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -203,7 +204,7 @@ class DPTImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -224,18 +224,19 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
return_tensors: Optional[Union[str, TensorType]],
keep_aspect_ratio: bool,
ensure_multiple_of: Optional[int],
do_pad: bool,
size_divisor: Optional[int],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
if do_reduce_labels:
images = self.reduce_label(images)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -251,7 +252,7 @@ class DPTImageProcessorFast(BeitImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -17,19 +17,9 @@
from functools import lru_cache
from typing import Optional, Union
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
SizeDict,
)
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
@ -181,11 +171,12 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
include_top: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -195,7 +186,7 @@ class EfficientNetImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -367,10 +367,11 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
do_map_pixels: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> "torch.Tensor":
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -380,7 +381,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
@ -432,6 +433,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
codebook_do_normalize: Optional[bool],
codebook_image_mean: Optional[Union[float, list[float]]],
codebook_image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
@ -448,6 +450,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
do_map_pixels=False,
image_mean=image_mean,
image_std=image_std,
disable_grouping=disable_grouping,
return_tensors=return_tensors,
)
data = {
@ -468,6 +471,7 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
do_map_pixels=codebook_do_map_pixels,
image_mean=codebook_image_mean,
image_std=codebook_image_std,
disable_grouping=disable_grouping,
return_tensors=return_tensors,
)
data["codebook_pixel_values"] = codebook_processed_images

View File

@ -25,12 +25,7 @@ from ...image_processing_utils_fast import (
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
SizeDict,
)
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
@ -205,12 +200,13 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched processing
processed_images_grouped = {}
num_crops_grouped = {}
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
for shape_images, stacked_images in grouped_images.items():
if do_pan_and_scan:
pas_images, num_crops = self._process_images_for_pan_and_scan(
@ -224,7 +220,9 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
stacked_images = [stacked_images] + pas_images
# Group images by size for batched resizing (this will typically group thumbnails together and cropped patches together)
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(stacked_images)
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
stacked_images, disable_grouping=disable_grouping
)
for shape, stacked_image_patches in grouped_image_patches.items():
stacked_image_patches = self.resize(
image=stacked_image_patches,
@ -254,7 +252,7 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize

View File

@ -23,13 +23,7 @@ from ...image_processing_utils_fast import (
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
PILImageResampling,
SizeDict,
)
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
@ -177,10 +171,11 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
if crop_to_patches:
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
num_patches = {}
for shape, stacked_images in grouped_images.items():
@ -200,7 +195,7 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
num_patches = [1] * len(images)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -210,7 +205,7 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_idefics2 import *
from .image_processing_idefics2 import *
from .image_processing_idefics2_fast import *
from .modeling_idefics2 import *
from .processing_idefics2 import *
else:

View File

@ -0,0 +1,318 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
SizeDict,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
from .image_processing_idefics2 import convert_to_rgb
if is_torchvision_available():
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
def get_resize_output_image_size(image: "torch.Tensor", size: SizeDict) -> tuple[int, int]:
"""
Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
Args:
image (`torch.Tensor`):
Image to resize.
size (`SizeDict`):
Size of the output image containing the keys "shortest_edge" and "longest_edge".
Returns:
The output size of the image after resizing.
"""
height, width = image.size()[-2:]
min_len = size.shortest_edge
max_len = size.longest_edge
aspect_ratio = width / height
if width >= height and width > max_len:
width = max_len
height = int(width / aspect_ratio)
elif height > width and height > max_len:
height = max_len
width = int(height * aspect_ratio)
height = max(height, min_len)
width = max(width, min_len)
return height, width
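A worked instance of this rule, assuming the module path below once this file lands:

```python
import torch
from transformers.image_processing_utils_fast import SizeDict
from transformers.models.idefics2.image_processing_idefics2_fast import get_resize_output_image_size

image = torch.rand(3, 1000, 3000)  # H x W, aspect ratio 3.0
size = SizeDict(shortest_edge=378, longest_edge=980)
# width 3000 > 980 -> width = 980, height = int(980 / 3.0) = 326, then clamped up to 378
print(get_resize_output_image_size(image, size))  # (378, 980)
```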
def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
"""
Get the maximum height and width across all images in a batch.
"""
image_sizes = []
for images in images_list:
for image in images:
image_sizes.append(image.size()[-2:])
max_height = max(size[0] for size in image_sizes)
max_width = max(size[1] for size in image_sizes)
return (max_height, max_width)
def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`torch.Tensor`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = image.size()[-2:]
mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
mask[:input_height, :input_width] = 1
return mask
class Idefics2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
do_image_splitting (`bool`, *optional*, defaults to `False`):
Whether to split the image into a sequence of 4 equal sub-images concatenated with the original image.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad images to the largest height and width in the batch.
"""
do_image_splitting: Optional[bool]
do_pad: Optional[bool]
@auto_docstring
class Idefics2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
do_rescale = True
do_normalize = True
do_pad = True
do_convert_rgb = True
do_image_splitting = False
size = {"shortest_edge": 378, "longest_edge": 980}
model_input_names = ["pixel_values", "pixel_attention_mask"]
valid_kwargs = Idefics2FastImageProcessorKwargs
def convert_to_rgb(self, image: ImageInput) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
"""
return convert_to_rgb(image)
def resize(
self, image: torch.Tensor, size: SizeDict, interpolation: Optional["F.InterpolationMode"] = None, **kwargs
) -> torch.Tensor:
"""
Resize an image using torchvision's functional resize.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if size.shortest_edge and size.longest_edge:
new_size = get_resize_output_image_size(image, size)
elif size.height and size.width:
new_size = (size.height, size.width)
else:
raise ValueError("Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys.")
image = F.resize(image, size=new_size, interpolation=interpolation, **kwargs)
return image
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare a nested images structure for processing.
"""
return make_nested_list_of_images(images)
def split_images(
self,
images: "torch.Tensor",
) -> list["torch.Tensor"]:
"""
Split a batch of images into 4 equal sub-images, and concatenate that sequence with the original image.
"""
height, width = images.size()[-2:]
mid_width = width // 2
mid_height = height // 2
batch_split_images = [
images[..., :mid_height, :mid_width],
images[..., :mid_height, mid_width:],
images[..., mid_height:, :mid_width],
images[..., mid_height:, mid_width:],
images,
]
# transpose the batch dimension to the first dimension
batch_split_images = [[image[i] for image in batch_split_images] for i in range(len(batch_split_images[0]))]
return batch_split_images
def pad(
self, image: "torch.Tensor", padded_size: tuple[int, int], fill: int = 0
) -> tuple["torch.Tensor", "torch.Tensor"]:
"""
Pad an image to the specified size and create the corresponding pixel mask.
"""
original_size = image.shape[-2:]
padding_bottom = padded_size[0] - original_size[0]
padding_right = padded_size[1] - original_size[1]
if padding_bottom < 0 or padding_right < 0:
raise ValueError(
f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
f"original size. Got padded size: {padded_size}, original size: {original_size}."
)
# Only pad if necessary
if original_size != padded_size:
# torchvision's pad takes a 4-element tuple for 2D padding: (left, top, right, bottom)
padding = (0, 0, padding_right, padding_bottom)
# Use constant padding to match slow implementation
image = F.pad(image, padding, fill=fill, padding_mode="constant")
# Create pixel mask to match the slow implementation
pixel_mask = torch.zeros(padded_size, dtype=torch.int64, device=image.device)
pixel_mask[: original_size[0], : original_size[1]] = 1
return image, pixel_mask
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics2FastImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: list[list["torch.Tensor"]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: Optional[bool],
do_image_splitting: Optional[bool],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Process a batch of images for the model.
"""
grouped_images, grouped_images_index = group_images_by_shape(
images, is_nested=True, disable_grouping=disable_grouping
)
split_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_image_splitting:
stacked_images = self.split_images(stacked_images)
split_images_grouped[shape] = stacked_images
split_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
if do_image_splitting:
# flatten the doubly nested list to a nested list
for i, group_images in enumerate(split_images):
split_images[i] = [image for sublist in group_images for image in sublist]
# Group images by size for further processing
grouped_images, grouped_images_index = group_images_by_shape(
split_images, is_nested=True, disable_grouping=disable_grouping
)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(
resized_images, is_nested=True, disable_grouping=disable_grouping
)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
if do_pad:
# Get max images per batch
max_num_images = max(len(images_) for images_ in processed_images)
max_height, max_width = get_max_height_width(processed_images)
processed_images_padded = torch.zeros(
len(processed_images),
max_num_images,
*(processed_images[0][0].shape[0], max_height, max_width),
device=processed_images[0][0].device,
)
pixel_attention_masks = torch.zeros(
len(processed_images),
max_num_images,
*(max_height, max_width),
device=processed_images[0][0].device,
)
for i, images in enumerate(processed_images):
for j, image in enumerate(images):
processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
image, (max_height, max_width)
)
processed_images = processed_images_padded
if do_pad:
data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
elif return_tensors == "pt":
data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
else:
data = {"pixel_values": processed_images}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Idefics2ImageProcessorFast"]

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_idefics3 import *
from .image_processing_idefics3 import *
from .image_processing_idefics3_fast import *
from .modeling_idefics3 import *
from .processing_idefics3 import *
else:

View File

@ -770,6 +770,7 @@ class Idefics3ImageProcessor(BaseImageProcessor):
split_image_array, rows, cols = self.split_image(
image,
max_image_size=max_image_size,
resample=resample,
input_data_format=input_data_format,
)
split_image_arrays.extend(split_image_array)

View File

@ -0,0 +1,507 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union
import torch
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
SizeDict,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
if is_torchvision_available():
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
def _resize_output_size_rescale_to_max_len(
height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
) -> tuple[int, int]:
"""
Get the output size of the image after rescaling its longest edge to `max_len`, preserving the aspect ratio and enforcing a minimum side length of `min_len`.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
min_len (`int`, *optional*, defaults to 1):
Minimum size of the output image.
max_len (`int`, *optional*, defaults to the maximum size of the image):
Maximum size of the output image.
Returns:
The output size of the image after resizing.
"""
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
if width >= height:
width = max_len
height = int(width / aspect_ratio)
if height % 2 != 0:
height += 1
elif height > width:
height = max_len
width = int(height * aspect_ratio)
if width % 2 != 0:
width += 1
# Avoid resizing to a size smaller than min_len
height = max(height, min_len)
width = max(width, min_len)
return height, width
def _resize_output_size_scale_below_upper_bound(
height: int, width: int, max_len: Optional[int] = None
) -> tuple[int, int]:
"""
Get the output size of the image after scaling it down, if needed, so that its longest side does not exceed `max_len`, preserving the aspect ratio.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
max_len (`int`, *optional*, defaults to the maximum size of the image):
Defines the maximum dimensions of the image.
Returns:
The output size of the image after resizing.
"""
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
if width >= height and width > max_len:
width = max_len
height = int(width / aspect_ratio)
elif height > width and height > max_len:
height = max_len
width = int(height * aspect_ratio)
# Avoid resizing to a size smaller than 1
height = max(height, 1)
width = max(width, 1)
return height, width
def get_resize_output_image_size(
image,
resolution_max_side: int,
) -> tuple[int, int]:
"""
Get the output size of the image after resizing its longest edge to `resolution_max_side`, preserving the aspect ratio.
Args:
image (`torch.Tensor`):
Image to resize.
resolution_max_side (`int`):
The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
input aspect ratio.
Returns:
The output size of the image after resizing.
"""
height, width = image.size()[-2:]
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
# Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
return height, width
def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
"""
Get the maximum height and width across all images in a batch.
"""
image_sizes = []
for images in images_list:
for image in images:
image_sizes.append(image.size()[-2:])
max_height = max(size[0] for size in image_sizes)
max_width = max(size[1] for size in image_sizes)
return (max_height, max_width)
def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`torch.Tensor`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = image.size()[-2:]
mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
mask[:input_height, :input_width] = 1
return mask
class Idefics3FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
do_image_splitting (`bool`, *optional*, defaults to `True`):
Whether to split the image into sub-images concatenated with the original image. They are split into patches
such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
return_row_col_info (`bool`, *optional*, defaults to `False`):
Whether to return the row and column information of the images.
"""
do_pad: Optional[bool]
do_image_splitting: Optional[bool]
max_image_size: Optional[dict[str, int]]
return_row_col_info: Optional[bool]
@auto_docstring
class Idefics3ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.LANCZOS
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"longest_edge": 4 * 364}
max_image_size = {"longest_edge": 364}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_image_splitting = True
do_pad = True
return_row_col_info = False
valid_kwargs = Idefics3FastImageProcessorKwargs
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare a nested images structure for processing.
"""
return make_nested_list_of_images(images)
def resize(
self,
image: "torch.Tensor",
size: SizeDict,
interpolation: "F.InterpolationMode" = None,
antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
"""
Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
resized to keep the input aspect ratio. Can also be used with size.height and size.width.
Args:
image (`torch.Tensor`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
antialias (`bool`, *optional*, defaults to `True`):
Whether to use antialiasing when resizing the image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if interpolation == F.InterpolationMode.LANCZOS:
logger.warning_once(
"You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
"BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you "
"want full consistency with the original model."
)
interpolation = F.InterpolationMode.BICUBIC
if size.longest_edge:
size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
elif size.height and size.width:
size = (size.height, size.width)
else:
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
return F.resize(image, size, interpolation=interpolation, antialias=antialias)
def split_images(
self,
images: torch.Tensor,
max_image_size: dict[str, int],
interpolation: "F.InterpolationMode" = None,
):
"""
Split an image into squares of side max_image_size, plus the original image resized to max_image_size.
That means that a single image becomes a sequence of images.
This is a "trick" to spend more compute on each image with no changes in the vision encoder.
1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
sub-images of the same size each (image_size, image_size). Typically, 364x364.
3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
Args:
images (`torch.Tensor`):
Images to split.
max_image_size (`Dict[str, int]`):
Maximum size of the output image. If the image is larger than this size, it will be split into
patches of this size, and the original image will be concatenated with the patches, resized to max_size.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
"""
batch_size, num_channels, height, width = images.size()
height_dim, width_dim = 2, 3
max_height = max_width = max_image_size["longest_edge"]
frames = []
if height > max_height or width > max_width:
# Calculate the number of splits
num_splits_h = math.ceil(height / max_height)
num_splits_w = math.ceil(width / max_width)
# Split the images by height, then by width
frames = (
images.unfold(height_dim, size=max_height, step=max_height)
.unfold(width_dim, size=max_width, step=max_width)
.contiguous()
.view(batch_size, num_channels, -1, max_height, max_width)
.permute(0, 2, 1, 3, 4)
) # batch_size x n_frames x num_channels x height x width
# For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
global_image_height, global_image_width = max_height, max_width
images = self.resize(
images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
)
frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
else:
num_splits_h, num_splits_w = 0, 0
frames = images.unsqueeze(1)
num_splits_h = [num_splits_h] * batch_size
num_splits_w = [num_splits_w] * batch_size
return frames, num_splits_h, num_splits_w
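A standalone sketch of the unfold-based tiling used above, with small shapes chosen so the result is easy to verify:

```python
import torch

# One 3-channel 8x8 "image", tiled into non-overlapping 4x4 patches.
images = torch.arange(1 * 3 * 8 * 8, dtype=torch.float32).view(1, 3, 8, 8)
tiles = (
    images.unfold(2, size=4, step=4)  # batch x C x 2 x W x 4
    .unfold(3, size=4, step=4)        # batch x C x 2 x 2 x 4 x 4
    .contiguous()
    .view(1, 3, -1, 4, 4)             # batch x C x 4 x 4 x 4
    .permute(0, 2, 1, 3, 4)           # batch x n_tiles x C x 4 x 4
)
print(tiles.shape)  # torch.Size([1, 4, 3, 4, 4])
```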
def resize_for_vision_encoder(
self,
image: torch.Tensor,
vision_encoder_max_size: int,
interpolation: "F.InterpolationMode" = None,
):
"""
Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
Args:
image (`torch.Tensor`):
Images to resize.
vision_encoder_max_size (`int`):
The image is resized so that both its height and width are multiples of this value, while
preserving the aspect ratio as closely as possible.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
"""
height, width = image.size()[-2:]
aspect_ratio = width / height
if width >= height:
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
height = int(width / aspect_ratio)
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
elif height > width:
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
width = int(height * aspect_ratio)
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
new_size = SizeDict(height=height, width=width)
return self.resize(image, size=new_size, interpolation=interpolation)
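Worked through for a 500x800 image with `vision_encoder_max_size=364`, the rounding above proceeds as follows (a sketch of the `width >= height` branch):

```python
import math

vision_encoder_max_size = 364
height, width = 500, 800
aspect_ratio = width / height                                                    # 1.6
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size     # 1092
height = int(width / aspect_ratio)                                               # 682
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size   # 728
print(height, width)  # 728 1092 -- both multiples of 364, ready for split_images
```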
def pad(
self,
image: torch.Tensor,
padded_size: tuple[int, int],
fill: int = 0,
return_pixel_mask: bool = True,
):
original_size = image.shape[-2:]
padding_bottom = padded_size[0] - original_size[0]
padding_right = padded_size[1] - original_size[1]
if padding_bottom < 0 or padding_right < 0:
raise ValueError(
f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
f"original size. Got padded size: {padded_size}, original size: {original_size}."
)
# Only pad if necessary
if original_size != padded_size:
padding = (0, 0, padding_right, padding_bottom)
image = F.pad(image, padding, fill=fill, padding_mode="constant")
# Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
pixel_mask = None
if return_pixel_mask:
pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
pixel_mask[: original_size[0], : original_size[1]] = 1
return image, pixel_mask
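The padding itself is a plain constant bottom/right pad. A standalone sketch of the equivalent torchvision call (the shapes are hypothetical):

```python
import torch
from torchvision.transforms import functional as F

# Pad a 3x300x360 image to 364x364 and build its pixel mask
image = torch.rand(3, 300, 360)
padded = F.pad(image, [0, 0, 364 - 360, 364 - 300], fill=0, padding_mode="constant")
mask = torch.zeros(padded.shape[-2:], dtype=torch.int64)
mask[:300, :360] = 1  # 1 marks valid pixels, 0 marks padding
print(padded.shape, int(mask.sum()))  # torch.Size([3, 364, 364]) 108000
```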
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics3FastImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: list[list["torch.Tensor"]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: Optional[bool],
do_image_splitting: Optional[bool],
max_image_size: Optional[dict[str, int]],
return_row_col_info: Optional[bool],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Process a batch of images for the model.
"""
grouped_images, grouped_images_index = group_images_by_shape(
images, is_nested=True, disable_grouping=disable_grouping
)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)
grouped_images, grouped_images_index = group_images_by_shape(
resized_images, is_nested=True, disable_grouping=disable_grouping
)
split_images_grouped = {}
if do_image_splitting:
rows_grouped = {}
cols_grouped = {}
for shape, stacked_images in grouped_images.items():
stacked_images = self.resize_for_vision_encoder(
stacked_images, max_image_size["longest_edge"], interpolation=interpolation
)
stacked_images, rows, cols = self.split_images(
stacked_images, max_image_size=max_image_size, interpolation=interpolation
)
split_images_grouped[shape] = stacked_images
rows_grouped[shape] = rows
cols_grouped[shape] = cols
processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
# flatten the doubly nested list into a nested list
for i, group_images in enumerate(processed_images):
processed_images[i] = [image for sublist in group_images for image in sublist]
else:
for shape, stacked_images in grouped_images.items():
# We square the images to max_image_size
stacked_images = self.resize(
image=stacked_images,
size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
interpolation=interpolation,
)
split_images_grouped[shape] = stacked_images
processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
rows = [[0] * len(images) for images in processed_images]
cols = [[0] * len(images) for images in processed_images]
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(
processed_images, is_nested=True, disable_grouping=disable_grouping
)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
if do_pad:
# Get max images per batch
max_num_images = max(len(images_) for images_ in processed_images)
max_height, max_width = get_max_height_width(processed_images)
processed_images_padded = torch.zeros(
len(processed_images),
max_num_images,
*(processed_images[0][0].shape[0], max_height, max_width),
device=processed_images[0][0].device,
)
pixel_attention_masks = torch.zeros(
len(processed_images),
max_num_images,
*(max_height, max_width),
device=processed_images[0][0].device,
)
for i, images in enumerate(processed_images):
for j, image in enumerate(images):
processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
image, (max_height, max_width)
)
processed_images = processed_images_padded
if do_pad:
data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
elif return_tensors == "pt":
data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
else:
data = {"pixel_values": processed_images}
# This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
encoding = BatchFeature(data=data, tensor_type=return_tensors)
if return_row_col_info:
encoding["rows"] = rows
encoding["cols"] = cols
return encoding
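Putting the pipeline together, a minimal usage sketch (assuming the processor can be constructed with its class defaults; the random image and patch counts are illustrative):

```python
import numpy as np
from PIL import Image

from transformers import Idefics3ImageProcessorFast

processor = Idefics3ImageProcessorFast()
image = Image.fromarray(np.random.randint(0, 255, (500, 800, 3), dtype=np.uint8))
out = processor(image, return_tensors="pt", return_row_col_info=True)
print(out.pixel_values.shape)          # (1, num_patches, 3, 364, 364)
print(out.pixel_attention_mask.shape)  # (1, num_patches, 364, 364)
print(out.rows, out.cols)              # splits per image, consumed by the processor
```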
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("return_row_col_info", None)
return encoder_dict
__all__ = ["Idefics3ImageProcessorFast"]

View File

@ -91,6 +91,7 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
apply_ocr: bool,
ocr_lang: Optional[str],
tesseract_config: Optional[str],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
@ -111,7 +112,7 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
boxes_batch.append(boxes)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -121,7 +122,7 @@ class LayoutLMv2ImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# flip color channels from RGB to BGR (as Detectron2 requires this)

View File

@ -16,11 +16,7 @@
from typing import Optional, Union
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import ChannelDimension, group_images_by_shape, reorder_images
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
@ -106,6 +102,7 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):
ocr_lang: Optional[str],
tesseract_config: Optional[str],
return_tensors: Optional[Union[str, TensorType]],
disable_grouping: Optional[bool],
**kwargs,
) -> BatchFeature:
# Tesseract OCR to get words + normalized bounding boxes
@ -125,7 +122,7 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):
boxes_batch.append(boxes)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -135,7 +132,7 @@ class LayoutLMv3ImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -393,13 +393,14 @@ class Llama4ImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
possible_resolutions = torch.tensor(possible_resolutions, device=images[0].device)
# process images by batch, grouped by shape
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
grouped_processed_images = {}
grouped_aspect_ratios = {}
for shape, stacked_images in grouped_images.items():

View File

@ -16,9 +16,7 @@
from typing import Optional, Union
from ...image_processing_utils import (
BatchFeature,
)
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
@ -147,10 +145,11 @@ class LlavaImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_pad:
@ -162,7 +161,7 @@ class LlavaImageProcessorFast(BaseImageProcessorFast):
# Group images by size for batched resizing
# Needed in case do_pad is False, or padding returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(padded_images)
grouped_images, grouped_images_index = group_images_by_shape(padded_images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -172,7 +171,7 @@ class LlavaImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -242,7 +242,9 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: bool,
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
processed_images = []
image_sizes = []
@ -271,7 +273,9 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches, disable_grouping=disable_grouping
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(

View File

@ -248,7 +248,9 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
image_std: Optional[Union[float, list[float]]],
do_pad: bool,
batch_num_images: list[int],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
processed_images = []
image_sizes = []
@ -287,7 +289,9 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches, disable_grouping=disable_grouping
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(

View File

@ -169,7 +169,9 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
image_std: Optional[Union[float, list[float]]],
do_pad: bool,
batch_num_images: list[int],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
processed_images = []
image_sizes = []
@ -208,7 +210,9 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
# Group images by size for batched processing
processed_image_patches_grouped = {}
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches, disable_grouping=disable_grouping
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(

View File

@ -96,11 +96,12 @@ class PerceiverImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
@ -112,7 +113,7 @@ class PerceiverImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize

View File

@ -23,11 +23,7 @@ from ...image_processing_utils_fast import (
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
ImageInput,
PILImageResampling,
SizeDict,
)
from ...image_utils import ImageInput, PILImageResampling, SizeDict
from ...processing_utils import Unpack
from ...utils import (
TensorType,
@ -38,9 +34,7 @@ from ...utils import (
is_vision_available,
logging,
)
from .image_processing_pixtral import (
get_resize_output_image_size,
)
from .image_processing_pixtral import get_resize_output_image_size
logger = logging.get_logger(__name__)
@ -164,12 +158,13 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
patch_size = get_size_dict(patch_size, default_to_square=True)
patch_size = SizeDict(**patch_size)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -181,7 +176,7 @@ class PixtralImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
batch_image_sizes = [grouped_images_index[i][0] for i in range(len(grouped_images_index))]
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():

View File

@ -16,11 +16,7 @@
from typing import Optional, Union
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
)
from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs
from ...image_transforms import (
ChannelDimension,
get_resize_output_image_size,
@ -225,11 +221,12 @@ class PoolFormerImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -241,7 +238,7 @@ class PoolFormerImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:

View File

@ -139,6 +139,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
do_convert_rgb: bool,
input_data_format: Optional[Union[str, ChannelDimension]],
device: Optional[Union[str, torch.device]],
disable_grouping: Optional[bool],
):
"""
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
@ -191,7 +192,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
resized_height, resized_width = height, width
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@ -210,7 +211,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
@ -270,6 +271,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
device: Optional["torch.device"] = None,
disable_grouping: Optional[bool] = False,
**kwargs,
):
r"""
@ -360,6 +362,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
disable_grouping=disable_grouping,
)
pixel_values.extend(patches)
vision_grid_thws.append(image_grid_thw)
@ -393,6 +396,7 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
disable_grouping=disable_grouping,
)
pixel_values_videos.extend(patches)
vision_grid_thws_videos.append(video_grid_thw)

View File

@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_smolvlm import *
from .image_processing_smolvlm import *
from .image_processing_smolvlm_fast import *
from .modeling_smolvlm import *
from .processing_smolvlm import *
else:

View File

@ -767,6 +767,7 @@ class SmolVLMImageProcessor(BaseImageProcessor):
split_image_array, rows, cols = self.split_image(
image,
max_image_size=max_image_size,
resample=resample,
input_data_format=input_data_format,
)
split_image_arrays.extend(split_image_array)

View File

@ -0,0 +1,497 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/smolvlm/modular_smolvlm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_smolvlm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
# Written by Orr Zohar
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Union
import torch
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
SizeDict,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging
if is_torchvision_available():
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
class SmolVLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
do_image_splitting (`bool`, *optional*, defaults to `True`):
Whether to split the image into sub-images concatenated with the original image. The image is split into patches
such that each patch has a size of `max_image_size["longest_edge"]` x `max_image_size["longest_edge"]`.
max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
return_row_col_info (`bool`, *optional*, defaults to `False`):
Whether to return the row and column information of the images.
"""
do_pad: Optional[bool]
do_image_splitting: Optional[bool]
max_image_size: Optional[dict[str, int]]
return_row_col_info: Optional[bool]
MAX_IMAGE_SIZE = 4096 # 4k resolution as absolute maximum
def _resize_output_size_rescale_to_max_len(
height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
) -> tuple[int, int]:
"""
Get the output size of the image after rescaling its longest edge to `max_len` while preserving the
aspect ratio, with a lower bound of `min_len` on both dimensions.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
min_len (`int`, *optional*, defaults to 1):
Minimum size of the output image.
max_len (`int`, *optional*, defaults to the maximum size of the image):
Maximum size of the output image.
Returns:
The output size of the image after resizing.
"""
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
if width >= height:
width = max_len
height = int(width / aspect_ratio)
if height % 2 != 0:
height += 1
elif height > width:
height = max_len
width = int(height * aspect_ratio)
if width % 2 != 0:
width += 1
# Avoid resizing to a size smaller than min_len
height = max(height, min_len)
width = max(width, min_len)
return height, width
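For instance, rescaling a 1000x3000 input to a longest edge of `4 * 364` (a worked sketch of the function above):

```python
h, w = _resize_output_size_rescale_to_max_len(1000, 3000, max_len=4 * 364)
print(h, w)  # 486 1456: the longest edge hits max_len; the height is rounded up to be even
```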
def _resize_output_size_scale_below_upper_bound(
height: int, width: int, max_len: Optional[dict[str, int]] = None
) -> tuple[int, int]:
"""
Scale the image size down, if needed, so that neither dimension exceeds `max_len`, preserving the aspect ratio.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
max_len (`int`, *optional*, defaults to the maximum size of the image):
Upper bound on the output dimensions.
Returns:
The output size of the image after resizing.
"""
max_len = max(height, width) if max_len is None else max_len
aspect_ratio = width / height
if width >= height and width > max_len:
width = max_len
height = int(width / aspect_ratio)
elif height > width and height > max_len:
height = max_len
width = int(height * aspect_ratio)
# Avoid resizing to a size smaller than 1
height = max(height, 1)
width = max(width, 1)
return height, width
def get_resize_output_image_size(
image,
resolution_max_side: int,
) -> tuple[int, int]:
"""
Get the output size of the image after resizing its longest edge to `resolution_max_side`, capped at `MAX_IMAGE_SIZE`.
Args:
image (`torch.Tensor`):
Image to resize.
resolution_max_side (`int`):
The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
input aspect ratio.
Returns:
The output size of the image after resizing.
"""
height, width = image.size()[-2:]
# Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
# Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
return height, width
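Combined, the two steps first rescale the longest edge and then clamp the result to `MAX_IMAGE_SIZE`. A sketch (exact values subject to integer rounding):

```python
import torch

# A wide 2000x6000 image with a requested longest edge above MAX_IMAGE_SIZE
h, w = get_resize_output_image_size(torch.rand(3, 2000, 6000), resolution_max_side=5000)
print(h, w)  # roughly (1364, 4096): step 1 gives 1666x5000, step 2 clamps the width to 4096
```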
def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
"""
Get the maximum height and width across all images in a batch.
"""
image_sizes = []
for images in images_list:
for image in images:
image_sizes.append(image.size()[-2:])
max_height = max(size[0] for size in image_sizes)
max_width = max(size[1] for size in image_sizes)
return (max_height, max_width)
@auto_docstring
class SmolVLMImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.LANCZOS
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"longest_edge": 4 * 364}
max_image_size = {"longest_edge": 364}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_image_splitting = True
do_pad = True
return_row_col_info = False
valid_kwargs = SmolVLMFastImageProcessorKwargs
def _prepare_images_structure(
self,
images: ImageInput,
) -> ImageInput:
"""
Prepare a nested structure of images for processing.
"""
return make_nested_list_of_images(images)
def resize(
self,
image: "torch.Tensor",
size: SizeDict,
interpolation: "F.InterpolationMode" = None,
antialias: bool = True,
**kwargs,
) -> "torch.Tensor":
"""
Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
resized to keep the input aspect ratio. Can also be used with size.height and size.width.
Args:
image (`torch.Tensor`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
antialias (`bool`, *optional*, defaults to `True`):
Whether to use antialiasing when resizing the image.
"""
interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
if interpolation == F.InterpolationMode.LANCZOS:
logger.warning_once(
"You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
"BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you "
"want full consistency with the original model."
)
interpolation = F.InterpolationMode.BICUBIC
if size.longest_edge:
size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
elif size.height and size.width:
size = (size.height, size.width)
else:
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
return F.resize(image, size, interpolation=interpolation, antialias=antialias)
def split_images(
self,
images: torch.Tensor,
max_image_size: dict[str, int],
interpolation: "F.InterpolationMode" = None,
):
"""
Split an image into squares of side max_image_size and the original image resized to max_image_size.
That means that a single image becomes a sequence of images.
This is a "trick" to spend more compute on each image with no changes in the vision encoder.
1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
sub-images, each of size (`max_image_size`, `max_image_size`). Typically 364x364.
3) Return the list of crops and the resized original image, along with the number of splits
along the height and the width.
Args:
images (`torch.Tensor`):
Images to split.
max_image_size (`Dict[str, int]`):
Maximum size of the output patches. If the image is larger than this size, it is split into
patches of this size, and the original image, resized to this size, is appended to the patches.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
"""
batch_size, num_channels, height, width = images.size()
height_dim, width_dim = 2, 3
max_height = max_width = max_image_size["longest_edge"]
frames = []
if height > max_height or width > max_width:
# Calculate the number of splits
num_splits_h = math.ceil(height / max_height)
num_splits_w = math.ceil(width / max_width)
# Split the images by height, then by width
frames = (
images.unfold(height_dim, size=max_height, step=max_height)
.unfold(width_dim, size=max_width, step=max_width)
.contiguous()
.view(batch_size, num_channels, -1, max_height, max_width)
.permute(0, 2, 1, 3, 4)
) # batch_size x n_frames x num_channels x height x width
# For the global image at the end, we resize it to match the max_image_size, for CPU memory efficiency
global_image_height, global_image_width = max_height, max_width
images = self.resize(
images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
)
frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
else:
num_splits_h, num_splits_w = 0, 0
frames = images.unsqueeze(1)
num_splits_h = [num_splits_h] * batch_size
num_splits_w = [num_splits_w] * batch_size
return frames, num_splits_h, num_splits_w
def resize_for_vision_encoder(
self,
image: torch.Tensor,
vision_encoder_max_size: int,
interpolation: "F.InterpolationMode" = None,
):
"""
Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
Args:
image (`torch.Tensor`):
Images to resize.
vision_encoder_max_size (`int`):
The image is resized so that both its height and width are multiples of this value, while
preserving the aspect ratio as closely as possible.
interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
`InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
"""
height, width = image.size()[-2:]
aspect_ratio = width / height
if width >= height:
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
height = int(width / aspect_ratio)
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
elif height > width:
height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
width = int(height * aspect_ratio)
width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
new_size = SizeDict(height=height, width=width)
return self.resize(image, size=new_size, interpolation=interpolation)
def pad(
self,
image: torch.Tensor,
padded_size: tuple[int, int],
fill: int = 0,
return_pixel_mask: bool = True,
):
original_size = image.shape[-2:]
padding_bottom = padded_size[0] - original_size[0]
padding_right = padded_size[1] - original_size[1]
if padding_bottom < 0 or padding_right < 0:
raise ValueError(
f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
f"original size. Got padded size: {padded_size}, original size: {original_size}."
)
# Only pad if necessary
if original_size != padded_size:
padding = (0, 0, padding_right, padding_bottom)
image = F.pad(image, padding, fill=fill, padding_mode="constant")
# Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
pixel_mask = None
if return_pixel_mask:
pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
pixel_mask[: original_size[0], : original_size[1]] = 1
return image, pixel_mask
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[SmolVLMFastImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
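A minimal end-to-end sketch of the fast processor with the kwargs defined above (default construction and the reported row/col values are illustrative):

```python
import numpy as np
from PIL import Image

from transformers import SmolVLMImageProcessorFast

processor = SmolVLMImageProcessorFast()
image = Image.fromarray(np.random.randint(0, 255, (500, 800, 3), dtype=np.uint8))
out = processor(image, do_image_splitting=True, return_row_col_info=True, return_tensors="pt")
print(out.pixel_values.shape)  # (1, num_patches, 3, 364, 364)
print(out.rows, out.cols)      # e.g. [[3]] [[4]]: splits along height and width
```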
def _preprocess(
self,
images: list[list["torch.Tensor"]],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: Optional[bool],
do_image_splitting: Optional[bool],
max_image_size: Optional[dict[str, int]],
return_row_col_info: Optional[bool],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Process a batch of images for the model.
"""
grouped_images, grouped_images_index = group_images_by_shape(
images, is_nested=True, disable_grouping=disable_grouping
)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)
grouped_images, grouped_images_index = group_images_by_shape(
resized_images, is_nested=True, disable_grouping=disable_grouping
)
split_images_grouped = {}
if do_image_splitting:
rows_grouped = {}
cols_grouped = {}
for shape, stacked_images in grouped_images.items():
stacked_images = self.resize_for_vision_encoder(
stacked_images, max_image_size["longest_edge"], interpolation=interpolation
)
stacked_images, rows, cols = self.split_images(
stacked_images, max_image_size=max_image_size, interpolation=interpolation
)
split_images_grouped[shape] = stacked_images
rows_grouped[shape] = rows
cols_grouped[shape] = cols
processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
# flatten the doubly nested list into a nested list
for i, group_images in enumerate(processed_images):
processed_images[i] = [image for sublist in group_images for image in sublist]
else:
for shape, stacked_images in grouped_images.items():
# We square the images to max_image_size
stacked_images = self.resize(
image=stacked_images,
size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
interpolation=interpolation,
)
split_images_grouped[shape] = stacked_images
processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
rows = [[0] * len(images) for images in processed_images]
cols = [[0] * len(images) for images in processed_images]
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(
processed_images, is_nested=True, disable_grouping=disable_grouping
)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
if do_pad:
# Get max images per batch
max_num_images = max(len(images_) for images_ in processed_images)
max_height, max_width = get_max_height_width(processed_images)
processed_images_padded = torch.zeros(
len(processed_images),
max_num_images,
*(processed_images[0][0].shape[0], max_height, max_width),
device=processed_images[0][0].device,
)
pixel_attention_masks = torch.zeros(
len(processed_images),
max_num_images,
*(max_height, max_width),
device=processed_images[0][0].device,
)
for i, images in enumerate(processed_images):
for j, image in enumerate(images):
processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
image, (max_height, max_width)
)
processed_images = processed_images_padded
if do_pad:
data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
elif return_tensors == "pt":
data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
else:
data = {"pixel_values": processed_images}
# This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
encoding = BatchFeature(data=data, tensor_type=return_tensors)
if return_row_col_info:
encoding["rows"] = rows
encoding["cols"] = cols
return encoding
def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
encoder_dict.pop("return_row_col_info", None)
return encoder_dict
__all__ = ["SmolVLMImageProcessorFast"]

View File

@ -34,7 +34,12 @@ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from ...utils import (
LossKwargs,
auto_docstring,
can_return_tuple,
logging,
)
from ..auto import AutoModel
from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig

View File

@ -25,6 +25,7 @@ from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from ..idefics3.configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
from ..idefics3.image_processing_idefics3 import Idefics3ImageProcessor
from ..idefics3.image_processing_idefics3_fast import Idefics3ImageProcessorFast
from ..idefics3.modeling_idefics3 import (
Idefics3BaseModelOutputWithPast,
Idefics3ForConditionalGeneration,
@ -160,6 +161,10 @@ class SmolVLMImageProcessor(Idefics3ImageProcessor):
pass
class SmolVLMImageProcessorFast(Idefics3ImageProcessorFast):
pass
class SmolVLMBaseModelOutputWithPast(Idefics3BaseModelOutputWithPast):
pass
@ -396,6 +401,7 @@ __all__ = [
"SmolVLMVisionConfig",
"SmolVLMConfig",
"SmolVLMImageProcessor",
"SmolVLMImageProcessorFast",
"SmolVLMForConditionalGeneration",
"SmolVLMPreTrainedModel",
"SmolVLMModel",

View File

@ -100,11 +100,11 @@ class Swin2SRImageProcessorFast(BaseImageProcessorFast):
rescale_factor: float,
do_pad: bool,
pad_size: int,
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
interpolation: Optional["F.InterpolationMode"],
**kwargs,
) -> BatchFeature:
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_image_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_rescale:

View File

@ -93,6 +93,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
@ -102,7 +103,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
This method overrides the base class method to include padding and pixel mask generation.
"""
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
@ -112,7 +113,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
@ -127,7 +128,9 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
# Handle padding if required
data = {}
if do_pad:
pixel_values, pixel_mask = self._pad_batch(processed_images, return_tensors)
pixel_values, pixel_mask = self._pad_batch(
processed_images, return_tensors, disable_grouping=disable_grouping
)
data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
else:
# If no padding, just return the processed images
@ -195,6 +198,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
self,
images: list["torch.Tensor"],
return_tensors: Optional[Union[str, TensorType]],
disable_grouping: Optional[bool],
) -> tuple:
"""
Pad a batch of images to the same size based on the maximum dimensions.
@ -210,7 +214,7 @@ class ViltImageProcessorFast(BaseImageProcessorFast):
max_size = get_max_height_width(images)
# Group images by shape before padding
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images = {}
processed_masks = {}

View File

@ -202,10 +202,12 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast):
image_std: Optional[Union[float, list[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
disable_grouping: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
grouped_trimaps, grouped_trimaps_index = group_images_by_shape(trimaps, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape in grouped_images:
stacked_images = grouped_images[shape]

View File

@ -188,11 +188,12 @@ class ZoeDepthImageProcessorFast(BaseImageProcessorFast):
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_rescale:

View File

@ -204,6 +204,15 @@ class ImageProcessorArgs:
"shape": None,
}
disable_grouping = {
"description": """
Whether to disable grouping of images by size, so that they are processed individually rather
than in batches. If None, this is set to True when the images are on CPU and to False otherwise.
This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
""",
"shape": None,
}
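A sketch of what the flag changes in the grouping helper (the return structure described in the comments is illustrative):

```python
import torch

from transformers.image_processing_utils_fast import group_images_by_shape

images = [torch.rand(3, 224, 224), torch.rand(3, 224, 224), torch.rand(3, 336, 336)]

# Grouping enabled: same-shape images are stacked and processed as one batch
grouped, index = group_images_by_shape(images, disable_grouping=False)
# -> e.g. two stacked tensors of shapes (2, 3, 224, 224) and (1, 3, 336, 336)

# Grouping disabled (the CPU-friendly choice per the note above): one entry per image
ungrouped, _ = group_images_by_shape(images, disable_grouping=True)
```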
class ModelArgs:
labels = {

View File

@ -298,11 +298,10 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
self._assert_slow_fast_tensors_equivalence(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(
image_encoding_slow.labels.float(), image_encoding_fast.labels.float()
)
self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
@ -324,7 +323,5 @@ class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(encoding_slow.labels.float(), encoding_fast.labels.float())

View File

@ -19,14 +19,11 @@ from typing import Optional, Union
import requests
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
@ -124,10 +121,6 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
def _assertEquivalence(self, a, b):
self.assertTrue(torch.allclose(a, b, atol=1e-1))
self.assertLessEqual(torch.mean(torch.abs(a - b)).item(), 1e-3)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
@ -146,8 +139,8 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self._assertEquivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assertEquivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())
@require_vision
@require_torch
@ -170,5 +163,5 @@ class BridgeTowerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self._assertEquivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assertEquivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_mask.float(), encoding_fast.pixel_mask.float())

View File

@ -418,15 +418,8 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_fast = image_processor_fast(
dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True
)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self.assertTrue(
torch.allclose(encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values, atol=1e-1)
)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.codebook_pixel_values - encoding_fast.codebook_pixel_values)).item(),
1e-3,
self._assert_slow_fast_tensors_equivalence(
encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values
)

View File

@ -286,7 +286,4 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

View File

@ -125,10 +125,7 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
def test_slow_fast_equivalence_batched_crop_to_patches(self):
# Prepare image inputs so that we have two groups of images with equal resolution with a group of images with
@ -144,10 +141,7 @@ class GotOcr2ProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.num_patches, encoding_fast.num_patches)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
def test_crop_to_patches(self):
# test slow image processor

View File

@ -1,4 +1,5 @@
# Copyright 2024 HuggingFace Inc.
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin
@ -28,6 +28,8 @@ if is_vision_available():
from transformers import Idefics2ImageProcessor
if is_torchvision_available():
from transformers import Idefics2ImageProcessorFast
if is_torch_available():
import torch
@ -88,10 +90,6 @@ class Idefics2ImageProcessingTester:
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to BridgeTowerImageProcessor,
assuming do_resize is set to True with a scalar size and size_divisor.
"""
if not batched:
shortest_edge = self.size["shortest_edge"]
longest_edge = self.size["longest_edge"]
@ -142,11 +140,6 @@ class Idefics2ImageProcessingTester:
numpify=False,
torchify=False,
):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
One can specify whether the images are of the same resolution or not.
"""
assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time"
batch_size = batch_size if batch_size is not None else self.batch_size
@ -162,23 +155,19 @@ class Idefics2ImageProcessingTester:
if equal_resolution:
width = height = max_resolution
else:
# To avoid getting image width/height 0
if size_divisor is not None:
# If `size_divisor` is defined, the image needs to have width/size >= `size_divisor`
min_resolution = max(size_divisor, min_resolution)
width, height = np.random.choice(np.arange(min_resolution, max_resolution), 2)
images.append(np.random.randint(255, size=(num_channels, width, height), dtype=np.uint8))
images_list.append(images)
if not numpify and not torchify:
# PIL expects the channel dimension as last dimension
images_list = [[Image.fromarray(np.moveaxis(image, 0, -1)) for image in images] for images in images_list]
if torchify:
images_list = [[torch.from_numpy(image) for image in images] for images in images_list]
if numpify:
# Numpy images are typically in channels last format
images_list = [[image.transpose(1, 2, 0) for image in images] for images in images_list]
return images_list
@ -188,6 +177,7 @@ class Idefics2ImageProcessingTester:
@require_vision
class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Idefics2ImageProcessor if is_vision_available() else None
fast_image_processing_class = Idefics2ImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -198,22 +188,23 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
@ -238,7 +229,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor_dict = self.image_processor_dict
image_processor_dict["image_mean"] = [0.5, 0.5, 0.5, 0.5]
image_processor_dict["image_std"] = [0.5, 0.5, 0.5, 0.5]
image_processing = self.image_processing_class(**image_processor_dict)
image_processing = image_processing_class(**image_processor_dict)
# create random numpy tensors
self.image_processor_tester.num_channels = 4
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
@ -266,7 +257,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for images in image_inputs:
@ -288,7 +279,7 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
@ -308,3 +299,104 @@ class Idefics2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
def test_image_splitting(self):
for image_processing_class in self.image_processor_list:
image_processor_dict = self.image_processor_dict.copy()
image_processor_dict["do_image_splitting"] = True
image_processing = image_processing_class(**image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(
equal_resolution=True, torchify=True, num_images=1
)
result = image_processing(image_inputs[0], return_tensors="pt")
self.assertEqual(result.pixel_values.shape[1], 5)
image_processor_dict["do_image_splitting"] = False
image_processing = image_processing_class(**image_processor_dict)
result = image_processing(image_inputs[0], return_tensors="pt")
if len(result.pixel_values.shape) == 5:
self.assertEqual(result.pixel_values.shape[1], 1)
else:
self.assertEqual(result.pixel_values.shape[1], self.image_processor_tester.num_channels)
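The `shape[1] == 5` assertion reflects Idefics2's splitting strategy: with `do_image_splitting=True`, each image becomes four equal crops plus the original, i.e. five frames per sample. A minimal sketch of that strategy (illustrative helper, not the library implementation):
```python
from PIL import Image

def split_into_patches(image: Image.Image) -> list[Image.Image]:
    # Four equal quadrants plus the full image -> 5 frames per sample.
    width, height = image.size
    mid_w, mid_h = width // 2, height // 2
    crops = [
        image.crop((0, 0, mid_w, mid_h)),           # top-left
        image.crop((mid_w, 0, width, mid_h)),       # top-right
        image.crop((0, mid_h, mid_w, height)),      # bottom-left
        image.crop((mid_w, mid_h, width, height)),  # bottom-right
    ]
    return crops + [image]

assert len(split_into_patches(Image.new("RGB", (100, 100)))) == 5
```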
def test_pixel_attention_mask(self):
for image_processing_class in self.image_processor_list:
image_processor_dict = self.image_processor_dict.copy()
image_processor_dict["do_pad"] = True
image_processing = image_processing_class(**image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
result = image_processing(image_inputs, return_tensors="pt")
self.assertIn("pixel_attention_mask", result)
self.assertEqual(result.pixel_attention_mask.shape[-2:], result.pixel_values.shape[-2:])
image_processor_dict["do_pad"] = False
image_processor_dict["do_image_splitting"] = False
image_processing = image_processing_class(**image_processor_dict)
equal_size_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
result = image_processing(equal_size_inputs, return_tensors="pt")
self.assertNotIn("pixel_attention_mask", result)
def test_convert_rgb(self):
for image_processing_class in self.image_processor_list:
rgba_image = Image.new("RGBA", (100, 100), (255, 0, 0, 128))
# Test with do_convert_rgb=True - this should work for all processors
image_processor_dict = self.image_processor_dict.copy()
image_processor_dict["do_convert_rgb"] = True
image_processing = image_processing_class(**image_processor_dict)
result = image_processing([rgba_image], return_tensors="pt")
self.assertIsNotNone(result.pixel_values)
rgb_image = rgba_image.convert("RGB")
image_processor_dict["do_convert_rgb"] = False
image_processing = image_processing_class(**image_processor_dict)
# Use the RGB image instead of RGBA when do_convert_rgb=False
result = image_processing([rgb_image], return_tensors="pt")
self.assertIsNotNone(result.pixel_values)
# Additional test: verifying proper handling of regular RGB images
rgb_image = Image.new("RGB", (100, 100), (255, 0, 0))
result = image_processing([rgb_image], return_tensors="pt")
self.assertIsNotNone(result.pixel_values)
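`do_convert_rgb` normalizes inputs to three-channel RGB before any other transform, which is what lets the RGBA input above pass through. A minimal PIL sketch of that normalization (the processors' own conversion may handle alpha compositing differently):
```python
from PIL import Image

rgba = Image.new("RGBA", (100, 100), (255, 0, 0, 128))
rgb = rgba.convert("RGB")  # three channels; the alpha channel is dropped
assert rgb.mode == "RGB" and len(rgb.getbands()) == 3
```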
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
dummy_images = self.image_processor_tester.prepare_image_inputs(
equal_resolution=False, num_images=5, torchify=True
)
# pop some images to have non-homogeneous batches:
indices_to_pop = [i if np.random.random() < 0.5 else None for i in range(len(dummy_images))]
for i in indices_to_pop:
if i is not None:
dummy_images[i].pop()
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
)
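Popping images from random samples yields batches whose inner lists differ in length, exercising the fast processors' shape-grouping path: torchvision ops run batched per group of identically shaped images. A simplified sketch of the grouping idea behind the `group_images_by_shape` utility (the real helper also returns an index map so the original order can be restored):
```python
from collections import defaultdict

import torch

def group_by_shape(images: list[torch.Tensor]) -> dict[tuple, torch.Tensor]:
    # Stack identically shaped images so each group is processed in one batched op.
    groups = defaultdict(list)
    for image in images:
        groups[tuple(image.shape)].append(image)
    return {shape: torch.stack(group) for shape, group in groups.items()}

images = [torch.rand(3, 64, 64), torch.rand(3, 32, 32), torch.rand(3, 64, 64)]
grouped = group_by_shape(images)
assert grouped[(3, 64, 64)].shape[0] == 2
```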

View File

@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -16,10 +17,11 @@
import unittest
import numpy as np
import requests
from transformers.image_utils import PILImageResampling
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin
@ -29,6 +31,9 @@ if is_vision_available():
from transformers import Idefics3ImageProcessor
if is_torchvision_available():
from transformers import Idefics3ImageProcessorFast
if is_torch_available():
import torch
@ -164,6 +169,7 @@ class Idefics3ImageProcessingTester:
@require_vision
class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Idefics3ImageProcessor if is_vision_available() else None
fast_image_processing_class = Idefics3ImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -174,25 +180,26 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
self.assertTrue(hasattr(image_processing, "max_image_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
self.assertTrue(hasattr(image_processing, "max_image_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
@ -216,7 +223,7 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processor_dict = self.image_processor_dict
image_processing = self.image_processing_class(**image_processor_dict)
image_processing = image_processing_class(**image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
@ -239,7 +246,7 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for images in image_inputs:
@ -261,7 +268,7 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
@ -281,3 +288,73 @@ class Idefics3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
dummy_image = dummy_image.resize((100, 150))
image_processor_slow = self.image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
image_processor_fast = self.fast_image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt", return_row_col_info=True)
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt", return_row_col_info=True)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
)
self.assertEqual(encoding_slow.rows, encoding_fast.rows)
self.assertEqual(encoding_slow.cols, encoding_fast.cols)
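`return_row_col_info=True` additionally returns, per image, the number of rows and columns of sub-images that splitting produced; the processor relies on these to interleave image placeholder tokens in the prompt, so the slow and fast paths must agree exactly. A hedged usage sketch (assumes the fast processor is instantiable with its default configuration; a real checkpoint config may differ):
```python
from PIL import Image
from transformers import Idefics3ImageProcessorFast

processor = Idefics3ImageProcessorFast()  # default configuration, illustrative only
out = processor(Image.new("RGB", (300, 200)), return_tensors="pt", return_row_col_info=True)
# out.rows[i][j] / out.cols[i][j]: the sub-image grid for image j of sample i.
print(out.rows, out.cols)
```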
@require_vision
@require_torch
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
dummy_images = self.image_processor_tester.prepare_image_inputs(
equal_resolution=False, num_images=5, torchify=True
)
# pop some images to have non-homogeneous batches:
# pop some images to have non-homogeneous batches:
indices_to_pop = [i if np.random.random() < 0.5 else None for i in range(len(dummy_images))]
for i in indices_to_pop:
if i is not None:
dummy_images[i].pop()
image_processor_slow = self.image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
image_processor_fast = self.fast_image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt", return_row_col_info=True)
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt", return_row_col_info=True)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=3e-1)
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
)
self.assertEqual(encoding_slow.rows, encoding_fast.rows)
self.assertEqual(encoding_slow.cols, encoding_fast.cols)

View File

@ -15,9 +15,21 @@
import unittest
import requests
from packaging import version
from transformers.testing_utils import require_pytesseract, require_torch, require_vision
from transformers.utils import is_pytesseract_available, is_torch_available, is_torchvision_available
from transformers.testing_utils import (
require_pytesseract,
require_torch,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
is_pytesseract_available,
is_torch_available,
is_torchvision_available,
)
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -157,16 +169,8 @@ class LayoutLMv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self.assertTrue(
torch.allclose(
encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255, atol=1e-1
)
)
self.assertLessEqual(
torch.mean(
torch.abs(encoding_slow.pixel_values.float() - encoding_fast.pixel_values.float()) / 255
).item(),
1e-3,
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255
)
@require_vision
@ -190,14 +194,28 @@ class LayoutLMv2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
self.assertTrue(
torch.allclose(
encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255, atol=1e-1
)
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_values.float() / 255, encoding_fast.pixel_values.float() / 255
)
self.assertLessEqual(
torch.mean(
torch.abs(encoding_slow.pixel_values.float() - encoding_fast.pixel_values.float()) / 255
).item(),
1e-3,
# Overriding as we can't use torch.testing.assert_close on int8 tensors
@slow
@require_torch_accelerator
@require_vision
def test_can_compile_fast_image_processor(self):
if self.fast_image_processing_class is None:
self.skipTest("Skipping compilation test as fast image processor is not defined")
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)
image_processor = self.fast_image_processing_class(**self.image_processor_dict)
output_eager = image_processor(input_image, device=torch_device, return_tensors="pt")
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values.float() / 255, output_compiled.pixel_values.float() / 255
)
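The override keeps the compile test usable for LayoutLMv2's uint8 outputs by comparing rescaled floats. The eager-vs-compiled pattern itself, in isolation (a sketch: `apply_ocr=False` is assumed here to avoid the pytesseract dependency, and torch >= 2.3 is required):
```python
import torch
from transformers import LayoutLMv2ImageProcessorFast

torch.compiler.reset()  # clear cached compilation state between runs
processor = LayoutLMv2ImageProcessorFast(apply_ocr=False)
image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)

eager = processor(image, return_tensors="pt")
compiled = torch.compile(processor, mode="reduce-overhead")(image, return_tensors="pt")
torch.testing.assert_close(
    eager.pixel_values.float() / 255, compiled.pixel_values.float() / 255, atol=1e-4, rtol=1e-4
)
```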

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
import numpy as np
@ -214,29 +213,6 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
@require_vision
@require_torch
def test_fast_is_faster_than_slow(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping speed test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping speed test as one of the image processors is not defined")
def measure_time(image_processor, image):
start = time.time()
_ = image_processor(image, return_tensors="pt")
return time.time() - start
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
fast_time = measure_time(image_processor_fast, image_inputs_list)
slow_time = measure_time(image_processor_slow, image_inputs_list)
self.assertLessEqual(fast_time, slow_time)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
@ -255,9 +231,7 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(
encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=100, atol=1e-1
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0])
@require_vision
@require_torch
@ -282,14 +256,8 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
for i in range(len(encoding_slow.pixel_values)):
self.assertTrue(
torch.allclose(encoding_slow.pixel_values[i][0], encoding_fast.pixel_values[i][0], atol=1e-1)
)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values[i][0] - encoding_fast.pixel_values[i][0])).item(), 1e-3
)
torch.testing.assert_close(
encoding_slow.pixel_values[0][0], encoding_fast.pixel_values[0][0], rtol=100, atol=1e-1
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_values[i][0], encoding_fast.pixel_values[i][0]
)
@slow
@ -309,8 +277,8 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
torch.testing.assert_close(
output_eager.pixel_values[0][0], output_compiled.pixel_values[0][0], rtol=1e-4, atol=1e-4
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values[0][0], output_compiled.pixel_values[0][0], atol=1e-4, rtol=1e-4, mean_atol=1e-5
)
@unittest.skip(reason="PixtralImageProcessor doesn't treat 4 channel PIL and numpy consistently yet") # FIXME Amy

View File

@ -362,6 +362,4 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(
encoding_slow.pixel_values, encoding_fast.pixel_values, rtol=100, atol=1e-2
) # @yoni bit weird that we have such diffs
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

View File

@ -18,14 +18,11 @@ import unittest
import requests
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
@ -150,7 +147,6 @@ class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_numpy_4_channels(self):
pass
# increase mean tolerance to 1e-3 -> 2e-3
# Ignore copy
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
@ -167,10 +163,7 @@ class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
# increase mean tolerance to 1e-3 -> 2e-3
# Ignore copy
@ -193,7 +186,4 @@ class Siglip2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-1)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 2e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

View File

@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -16,10 +17,11 @@
import unittest
import numpy as np
import requests
from transformers.image_utils import PILImageResampling
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin
@ -29,6 +31,9 @@ if is_vision_available():
from transformers import SmolVLMImageProcessor
if is_torchvision_available():
from transformers import SmolVLMImageProcessorFast
if is_torch_available():
import torch
@ -164,6 +169,7 @@ class SmolVLMImageProcessingTester:
@require_vision
class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SmolVLMImageProcessor if is_vision_available() else None
fast_image_processing_class = SmolVLMImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -174,25 +180,26 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
self.assertTrue(hasattr(image_processing, "max_image_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
self.assertTrue(hasattr(image_processing, "max_image_size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "do_image_splitting"))
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for sample_images in image_inputs:
@ -216,7 +223,7 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processor_dict = self.image_processor_dict
image_processing = self.image_processing_class(**image_processor_dict)
image_processing = image_processing_class(**image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
@ -239,7 +246,7 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for images in image_inputs:
@ -261,7 +268,7 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
@ -281,3 +288,73 @@ class SmolVLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
tuple(encoded_images.shape),
(self.image_processor_tester.batch_size, *expected_output_image_shape),
)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
dummy_image = dummy_image.resize((100, 150))
image_processor_slow = self.image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
image_processor_fast = self.fast_image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt", return_row_col_info=True)
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt", return_row_col_info=True)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
)
self.assertEqual(encoding_slow.rows, encoding_fast.rows)
self.assertEqual(encoding_slow.cols, encoding_fast.cols)
@require_vision
@require_torch
def test_slow_fast_equivalence_batched(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
dummy_images = self.image_processor_tester.prepare_image_inputs(
equal_resolution=False, num_images=5, torchify=True
)
# pop some images to have non-homogeneous batches:
indices_to_pop = [i if np.random.random() < 0.5 else None for i in range(len(dummy_images))]
for i in indices_to_pop:
if i is not None:
dummy_images[i].pop()
image_processor_slow = self.image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
image_processor_fast = self.fast_image_processing_class(
**self.image_processor_dict, resample=PILImageResampling.BICUBIC
)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt", return_row_col_info=True)
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt", return_row_col_info=True)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=3e-1)
self._assert_slow_fast_tensors_equivalence(
encoding_slow.pixel_attention_mask.float(), encoding_fast.pixel_attention_mask.float()
)
self.assertEqual(encoding_slow.rows, encoding_fast.rows)
self.assertEqual(encoding_slow.cols, encoding_fast.cols)

View File

@ -197,7 +197,7 @@ class Swin2SRImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoded_slow = image_processor_slow(image_inputs, return_tensors="pt").pixel_values
encoded_fast = image_processor_fast(image_inputs, return_tensors="pt").pixel_values
encoded_slow = image_processor_slow(image_inputs, return_tensors="pt")
encoded_fast = image_processor_fast(image_inputs, return_tensors="pt")
self.assertTrue(torch.allclose(encoded_slow, encoded_fast, atol=1e-1))
self._assert_slow_fast_tensors_equivalence(encoded_slow.pixel_values, encoded_fast.pixel_values)

View File

@ -312,10 +312,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, trimaps=dummy_trimap, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
def test_slow_fast_equivalence_batched(self):
# this only checks on equal resolution, since the slow processor doesn't work otherwise
@ -338,10 +335,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
encoding_slow = image_processor_slow(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, trimaps=dummy_trimaps, return_tensors="pt")
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@slow
@require_torch_accelerator

View File

@ -162,6 +162,10 @@ class ImageProcessingTestMixin:
self.image_processor_list = image_processor_list
def _assert_slow_fast_tensors_equivalence(self, slow_tensor, fast_tensor, atol=1e-1, rtol=1e-3, mean_atol=5e-3):
torch.testing.assert_close(slow_tensor, fast_tensor, atol=atol, rtol=rtol)
self.assertLessEqual(torch.mean(torch.abs(slow_tensor - fast_tensor)).item(), mean_atol)
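The new helper centralizes the two checks the per-model tests previously duplicated: an elementwise closeness bound and a cap on the mean absolute difference, both overridable per call. A standalone restatement showing what the three tolerances control:
```python
import torch

def assert_slow_fast_equivalence(slow, fast, atol=1e-1, rtol=1e-3, mean_atol=5e-3):
    # Elementwise: every value must lie within atol + rtol * |fast|.
    torch.testing.assert_close(slow, fast, atol=atol, rtol=rtol)
    # Aggregate: interpolation differences may spike locally but must stay small on average.
    assert torch.mean(torch.abs(slow - fast)).item() <= mean_atol

slow = torch.rand(3, 224, 224)
fast = slow + 1e-4 * torch.randn_like(slow)  # simulated slow/fast drift
assert_slow_fast_equivalence(slow, fast)
```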
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
@ -179,10 +183,7 @@ class ImageProcessingTestMixin:
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 5e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@require_vision
@require_torch
@ -205,10 +206,7 @@ class ImageProcessingTestMixin:
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1, rtol=1e-3)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 5e-3
)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
@require_vision
@require_torch
@ -577,8 +575,10 @@ class ImageProcessingTestMixin:
image_processor = torch.compile(image_processor, mode="reduce-overhead")
output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt")
torch.testing.assert_close(output_eager.pixel_values, output_compiled.pixel_values, rtol=1e-4, atol=1e-4)
self._assert_slow_fast_tensors_equivalence(
output_eager.pixel_values, output_compiled.pixel_values, atol=1e-4, rtol=1e-4, mean_atol=1e-5
)
class AnnotationFormatTestMixin: