Add Fast Image Processor for vilt (#37304)

* init vilt image processor fast

* Refactor image processor tests to use loop for all processors

* Add ViltImageProcessorFast with PyTorch-based optimized image processing

* Change made automatically by make fixup command

* Change made automatically by make fix-copies command

* Fix type hints in ViltImageProcessorFast for Python compatibility

* Define constants for image resizing based on COCO dataset aspect ratio

* Add missing property initializations to ViltImageProcessorFast

* Extract resize logic into dedicated method in ViltImageProcessorFast

* Extract padding logic into dedicated method

* Implement shape-based image grouping for optimized processing in Vilt

* Update test suite to verify ViltImageProcessorFast attributes

* Move variable declarations to _preprocess method parameters

* Remove unused parameters

* Rename _resize method to resize to override existing function

* Remove whitespace

* Remove unnecessary type check and conversion for stacked_images

* Remove redundant loop and apply padding directly to stacked images

* Refactor pad function to return images and mask as tuple instead of dict

* Add tests comparing padding masks in slow and fast implementations

* Update ViltImageProcessor tests to ensure compatibility between slow and fast implementations

* Replace add_start_docstrings with auto_docstring in ViltImageProcessorFast

* Move docstrings of custom args to ViltFastImageProcessorKwargs

* Use reorder_images function for both masks and images

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Jinyong Lee 2025-05-14 00:40:53 +09:00 committed by GitHub
parent 8771766a70
commit 342961f669
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 309 additions and 13 deletions

View File

@ -72,6 +72,11 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
[[autodoc]] ViltImageProcessor
- preprocess
## ViltImageProcessorFast
[[autodoc]] ViltImageProcessorFast
- preprocess
## ViltProcessor
[[autodoc]] ViltProcessor

View File

@ -161,7 +161,7 @@ else:
("upernet", ("SegformerImageProcessor",)),
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("videomae", ("VideoMAEImageProcessor",)),
("vilt", ("ViltImageProcessor",)),
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vit_hybrid", ("ViTHybridImageProcessor",)),

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_vilt import *
from .feature_extraction_vilt import *
from .image_processing_vilt import *
from .image_processing_vilt_fast import *
from .modeling_vilt import *
from .processing_vilt import *
else:

View File

@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Vilt."""
from typing import List, Optional, Union
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
get_max_height_width,
group_images_by_shape,
reorder_images,
)
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
# Set maximum size based on the typical aspect ratio of the COCO dataset
MAX_LONGER_EDGE = 1333
MAX_SHORTER_EDGE = 800
class ViltFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
Args:
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image. If `True`, will pad the images in the batch to the largest height and width
in the batch. Padding will be applied to the bottom and right with zeros.
size_divisor (`int`, *optional*, defaults to 32):
The size to make the height and width divisible by.
rescale_factor (`float`, *optional*, defaults to 1/255):
The factor to rescale the image by.
"""
do_pad: Optional[bool]
size_divisor: Optional[int]
rescale_factor: Optional[float]
@auto_docstring
class ViltImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"shortest_edge": 384}
do_resize = True
do_rescale = True
do_normalize = True
size_divisor = 32
do_pad = True
default_to_square = False
model_input_names = ["pixel_values", "pixel_mask"]
valid_kwargs = ViltFastImageProcessorKwargs
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
size_divisor: Optional[int],
do_pad: bool,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
This method overrides the base class method to include padding and pixel mask generation.
"""
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size, interpolation, size_divisor)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
# Handle padding if required
data = {}
if do_pad:
pixel_values, pixel_mask = self._pad_batch(processed_images, return_tensors)
data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
else:
# If no padding, just return the processed images
if return_tensors == "pt":
processed_images = torch.stack(processed_images)
data = {"pixel_values": processed_images}
return BatchFeature(data=data, tensor_type=return_tensors)
def resize(
self,
images: "torch.Tensor",
size: SizeDict,
interpolation: Optional["F.InterpolationMode"] = None,
size_divisor: Optional[int] = None,
) -> "torch.Tensor":
"""
Resize an image or batch of images to specified size.
Args:
images (`torch.Tensor`): Image or batch of images to resize.
size (`Dict[str, int]`): Size dictionary with shortest_edge key.
interpolation (`F.InterpolationMode`, *optional*): Interpolation method to use.
size_divisor (`int`, *optional*): Value to ensure height/width are divisible by.
Returns:
`torch.Tensor`: Resized image or batch of images.
"""
if interpolation is None:
interpolation = self.resample
# Resize with aspect ratio preservation
shorter = size.shortest_edge
longer = int(MAX_LONGER_EDGE / MAX_SHORTER_EDGE * shorter)
heights = images.shape[-2]
widths = images.shape[-1]
# Determine the new dimensions
if heights < widths:
new_heights = shorter
new_widths = widths * (shorter / heights)
else:
new_heights = heights * (shorter / widths)
new_widths = shorter
# Check if the longer side exceeds max size
if max(new_heights, new_widths) > longer:
scale = longer / max(new_heights, new_widths)
new_heights = new_heights * scale
new_widths = new_widths * scale
new_heights = int(new_heights + 0.5)
new_widths = int(new_widths + 0.5)
# Make dimensions divisible by size_divisor
if size_divisor is not None:
new_heights = new_heights // size_divisor * size_divisor
new_widths = new_widths // size_divisor * size_divisor
# Resize the image
return F.resize(images, [new_heights, new_widths], interpolation=interpolation)
def _pad_batch(
self,
images: list["torch.Tensor"],
return_tensors: Optional[Union[str, TensorType]],
) -> tuple:
"""
Pad a batch of images to the same size based on the maximum dimensions.
Args:
images (`list[torch.Tensor]`): List of images to pad.
return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return.
Returns:
`tuple`: Tuple containing padded images and pixel masks.
"""
# Calculate global maximum dimensions across all images
max_size = get_max_height_width(images)
# Group images by shape before padding
grouped_images, grouped_images_index = group_images_by_shape(images)
processed_images = {}
processed_masks = {}
for shape, stacked_images in grouped_images.items():
# Create mask template for efficient masking
if return_tensors == "pt" and len(stacked_images) > 0:
device = stacked_images.device
mask_template = torch.zeros(max_size, dtype=torch.int64, device=device)
original_size = stacked_images.shape[-2:]
needs_padding = original_size[0] != max_size[0] or original_size[1] != max_size[1]
if needs_padding:
padding_bottom = max_size[0] - original_size[0]
padding_right = max_size[1] - original_size[1]
padding = [0, 0, padding_right, padding_bottom]
padded_images = F.pad(stacked_images, padding, fill=0)
pixel_mask = mask_template.clone()
pixel_mask[: original_size[0], : original_size[1]].fill_(1)
pixel_masks = pixel_mask.unsqueeze(0).repeat(stacked_images.shape[0], 1, 1)
else:
padded_images = stacked_images
pixel_masks = torch.ones(
(stacked_images.shape[0], max_size[0], max_size[1]),
dtype=torch.int64,
device=stacked_images.device,
)
# Store processed group
processed_images[shape] = padded_images
processed_masks[shape] = pixel_masks
# Reorder images back to original order
padded_images = reorder_images(processed_images, grouped_images_index)
pixel_masks = reorder_images(processed_masks, grouped_images_index)
# Stack if tensors are requested for final result
if return_tensors == "pt" and padded_images:
padded_images = torch.stack(padded_images)
pixel_masks = torch.stack(pixel_masks)
return padded_images, pixel_masks
__all__ = ["ViltImageProcessorFast"]

View File

@ -16,9 +16,10 @@
import unittest
import numpy as np
import torch
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -28,6 +29,9 @@ if is_vision_available():
from transformers import ViltImageProcessor
if is_torchvision_available():
from transformers import ViltImageProcessorFast
class ViltImageProcessingTester:
def __init__(
@ -131,6 +135,7 @@ class ViltImageProcessingTester:
@require_vision
class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ViltImageProcessor if is_vision_available() else None
fast_image_processing_class = ViltImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -141,17 +146,43 @@ class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "model_input_names"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 30})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 30})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
def test_slow_fast_equivalence(self):
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
image_processor_slow = self.image_processing_class(**self.image_processor_dict, do_pad=True)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, do_pad=True)
slow_outputs = image_processor_slow(image_inputs, return_tensors="pt")
slow_pixel_values = slow_outputs.pixel_values
slow_pixel_mask = slow_outputs.pixel_mask
fast_outputs = image_processor_fast(image_inputs, return_tensors="pt")
fast_pixel_values = fast_outputs.pixel_values
fast_pixel_mask = fast_outputs.pixel_mask
self.assertEqual(slow_pixel_values.shape, fast_pixel_values.shape)
self.assertTrue(torch.allclose(slow_pixel_values, fast_pixel_values, atol=1e-2))
self.assertEqual(slow_pixel_mask.shape, fast_pixel_mask.shape)
self.assertTrue(torch.equal(slow_pixel_mask, fast_pixel_mask))