diff --git a/docs/source/en/model_doc/vilt.md b/docs/source/en/model_doc/vilt.md
index 107271e2c96..ea598cbbe25 100644
--- a/docs/source/en/model_doc/vilt.md
+++ b/docs/source/en/model_doc/vilt.md
@@ -72,6 +72,11 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
 [[autodoc]] ViltImageProcessor
     - preprocess
 
+## ViltImageProcessorFast
+
+[[autodoc]] ViltImageProcessorFast
+    - preprocess
+
 ## ViltProcessor
 
 [[autodoc]] ViltProcessor
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 7141cd2e9eb..2ec63a14b12 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -161,7 +161,7 @@ else:
         ("upernet", ("SegformerImageProcessor",)),
         ("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
         ("videomae", ("VideoMAEImageProcessor",)),
-        ("vilt", ("ViltImageProcessor",)),
+        ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
         ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("vit_hybrid", ("ViTHybridImageProcessor",)),
diff --git a/src/transformers/models/vilt/__init__.py b/src/transformers/models/vilt/__init__.py
index b70afa64543..4f154e79e7b 100644
--- a/src/transformers/models/vilt/__init__.py
+++ b/src/transformers/models/vilt/__init__.py
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .configuration_vilt import *
     from .feature_extraction_vilt import *
     from .image_processing_vilt import *
+    from .image_processing_vilt_fast import *
     from .modeling_vilt import *
     from .processing_vilt import *
 else:
diff --git a/src/transformers/models/vilt/image_processing_vilt_fast.py b/src/transformers/models/vilt/image_processing_vilt_fast.py
new file mode 100644
index 00000000000..764e1203bf6
--- /dev/null
+++ b/src/transformers/models/vilt/image_processing_vilt_fast.py
@@ -0,0 +1,259 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Vilt."""
+
+from typing import List, Optional, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+    BaseImageProcessorFast,
+    DefaultFastImageProcessorKwargs,
+    get_max_height_width,
+    group_images_by_shape,
+    reorder_images,
+)
+from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling, SizeDict
+from ...utils import (
+    TensorType,
+    auto_docstring,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+)
+
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
+
+# Set maximum size based on the typical aspect ratio of the COCO dataset
+MAX_LONGER_EDGE = 1333
+MAX_SHORTER_EDGE = 800
+
+
+class ViltFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
+    """
+    Args:
+        do_pad (`bool`, *optional*, defaults to `True`):
+            Whether to pad the image. If `True`, will pad the images in the batch to the largest height and width
+            in the batch. Padding will be applied to the bottom and right with zeros.
+        size_divisor (`int`, *optional*, defaults to 32):
+            The size to make the height and width divisible by.
+        rescale_factor (`float`, *optional*, defaults to 1/255):
+            The factor to rescale the image by.
+    """
+
+    do_pad: Optional[bool]
+    size_divisor: Optional[int]
+    rescale_factor: Optional[float]
+
+
+@auto_docstring
+class ViltImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BICUBIC
+    image_mean = IMAGENET_STANDARD_MEAN
+    image_std = IMAGENET_STANDARD_STD
+    size = {"shortest_edge": 384}
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+    size_divisor = 32
+    do_pad = True
+    default_to_square = False
+    model_input_names = ["pixel_values", "pixel_mask"]
+    valid_kwargs = ViltFastImageProcessorKwargs
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        interpolation: Optional["F.InterpolationMode"],
+        size_divisor: Optional[int],
+        do_pad: bool,
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, List[float]]],
+        image_std: Optional[Union[float, List[float]]],
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        This method overrides the base class method to include padding and pixel mask generation.
+        """
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        resized_images_grouped = {}
+
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(stacked_images, size, interpolation, size_divisor)
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        processed_images_grouped = {}
+
+        for shape, stacked_images in grouped_images.items():
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+        # Handle padding if required
+        data = {}
+        if do_pad:
+            pixel_values, pixel_mask = self._pad_batch(processed_images, return_tensors)
+            data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
+        else:
+            # If no padding, just return the processed images
+            if return_tensors == "pt":
+                processed_images = torch.stack(processed_images)
+            data = {"pixel_values": processed_images}
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def resize(
+        self,
+        images: "torch.Tensor",
+        size: SizeDict,
+        interpolation: Optional["F.InterpolationMode"] = None,
+        size_divisor: Optional[int] = None,
+    ) -> "torch.Tensor":
+        """
+        Resize an image or batch of images to specified size.
+
+        Args:
+            images (`torch.Tensor`): Image or batch of images to resize.
+            size (`Dict[str, int]`): Size dictionary with shortest_edge key.
+            interpolation (`F.InterpolationMode`, *optional*): Interpolation method to use.
+            size_divisor (`int`, *optional*): Value to ensure height/width are divisible by.
+
+        Returns:
+            `torch.Tensor`: Resized image or batch of images.
+        """
+        if interpolation is None:
+            interpolation = self.resample
+
+        # Resize with aspect ratio preservation
+        shorter = size.shortest_edge
+        longer = int(MAX_LONGER_EDGE / MAX_SHORTER_EDGE * shorter)
+
+        heights = images.shape[-2]
+        widths = images.shape[-1]
+
+        # Determine the new dimensions
+        if heights < widths:
+            new_heights = shorter
+            new_widths = widths * (shorter / heights)
+        else:
+            new_heights = heights * (shorter / widths)
+            new_widths = shorter
+
+        # Check if the longer side exceeds max size
+        if max(new_heights, new_widths) > longer:
+            scale = longer / max(new_heights, new_widths)
+            new_heights = new_heights * scale
+            new_widths = new_widths * scale
+
+        new_heights = int(new_heights + 0.5)
+        new_widths = int(new_widths + 0.5)
+
+        # Make dimensions divisible by size_divisor
+        if size_divisor is not None:
+            new_heights = new_heights // size_divisor * size_divisor
+            new_widths = new_widths // size_divisor * size_divisor
+
+        # Resize the image
+        return F.resize(images, [new_heights, new_widths], interpolation=interpolation)
+
+    def _pad_batch(
+        self,
+        images: list["torch.Tensor"],
+        return_tensors: Optional[Union[str, TensorType]],
+    ) -> tuple:
+        """
+        Pad a batch of images to the same size based on the maximum dimensions.
+
+        Args:
+            images (`list[torch.Tensor]`): List of images to pad.
+            return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return.
+
+        Returns:
+            `tuple`: Tuple containing padded images and pixel masks.
+        """
+        # Calculate global maximum dimensions across all images
+        max_size = get_max_height_width(images)
+
+        # Group images by shape before padding
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        processed_images = {}
+        processed_masks = {}
+
+        for shape, stacked_images in grouped_images.items():
+            # Create mask template for efficient masking (defined for every group so padding
+            # works for any requested tensor type, not only "pt")
+            device = stacked_images.device
+            mask_template = torch.zeros(max_size, dtype=torch.int64, device=device)
+
+            original_size = stacked_images.shape[-2:]
+            needs_padding = original_size[0] != max_size[0] or original_size[1] != max_size[1]
+
+            if needs_padding:
+                padding_bottom = max_size[0] - original_size[0]
+                padding_right = max_size[1] - original_size[1]
+                padding = [0, 0, padding_right, padding_bottom]
+
+                padded_images = F.pad(stacked_images, padding, fill=0)
+                pixel_mask = mask_template.clone()
+                pixel_mask[: original_size[0], : original_size[1]].fill_(1)
+                pixel_masks = pixel_mask.unsqueeze(0).repeat(stacked_images.shape[0], 1, 1)
+            else:
+                padded_images = stacked_images
+                pixel_masks = torch.ones(
+                    (stacked_images.shape[0], max_size[0], max_size[1]),
+                    dtype=torch.int64,
+                    device=stacked_images.device,
+                )
+
+            # Store processed group
+            processed_images[shape] = padded_images
+            processed_masks[shape] = pixel_masks
+
+        # Reorder images back to original order
+        padded_images = reorder_images(processed_images, grouped_images_index)
+        pixel_masks = reorder_images(processed_masks, grouped_images_index)
+
+        # Stack if tensors are requested for final result
+        if return_tensors == "pt" and padded_images:
+            padded_images = torch.stack(padded_images)
+            pixel_masks = torch.stack(pixel_masks)
+
+        return padded_images, pixel_masks
+
+
+__all__ = ["ViltImageProcessorFast"]
diff --git a/tests/models/vilt/test_image_processing_vilt.py b/tests/models/vilt/test_image_processing_vilt.py
index 74fd5ba8b09..6ad11086310 100644
--- a/tests/models/vilt/test_image_processing_vilt.py
+++ b/tests/models/vilt/test_image_processing_vilt.py
@@ -16,9 +16,10 @@ import unittest
 
 import numpy as np
+import torch
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
 
@@ -28,6 +29,9 @@ if is_vision_available():
 
     from transformers import ViltImageProcessor
 
+    if is_torchvision_available():
+        from transformers import ViltImageProcessorFast
+
 
 class ViltImageProcessingTester:
     def __init__(
@@ -131,6 +135,7 @@ class ViltImageProcessingTester:
 @require_vision
 class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = ViltImageProcessor if is_vision_available() else None
+    fast_image_processing_class = ViltImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -141,17 +146,43 @@ class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "do_resize"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "size_divisor"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "size_divisor"))
+            self.assertTrue(hasattr(image_processing, "do_pad"))
+            self.assertTrue(hasattr(image_processing, "resample"))
+            self.assertTrue(hasattr(image_processing, "do_rescale"))
+            self.assertTrue(hasattr(image_processing, "model_input_names"))
 
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"shortest_edge": 30})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"shortest_edge": 30})
 
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
-        self.assertEqual(image_processor.size, {"shortest_edge": 42})
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
+            self.assertEqual(image_processor.size, {"shortest_edge": 42})
+
+    def test_slow_fast_equivalence(self):
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict, do_pad=True)
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, do_pad=True)
+
+        slow_outputs = image_processor_slow(image_inputs, return_tensors="pt")
+        slow_pixel_values = slow_outputs.pixel_values
+        slow_pixel_mask = slow_outputs.pixel_mask
+
+        fast_outputs = image_processor_fast(image_inputs, return_tensors="pt")
+        fast_pixel_values = fast_outputs.pixel_values
+        fast_pixel_mask = fast_outputs.pixel_mask
+
+        self.assertEqual(slow_pixel_values.shape, fast_pixel_values.shape)
+        self.assertTrue(torch.allclose(slow_pixel_values, fast_pixel_values, atol=1e-2))
+
+        self.assertEqual(slow_pixel_mask.shape, fast_pixel_mask.shape)
+        self.assertTrue(torch.equal(slow_pixel_mask, fast_pixel_mask))
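
Usage sketch (not part of the diff): with this branch installed alongside torch and torchvision, the fast processor can be exercised against the slow one roughly as below. The checkpoint name is only an example, and the tolerance mirrors the equivalence test above. Two differently sized inputs force the padding path; for instance, a 480x640 image is resized to 384x512 (shortest edge 384, longer edge capped, both snapped down to a multiple of 32) and then padded to the batch maximum along with its pixel mask.

# Rough slow/fast comparison for the ViLT image processors.
# Assumes this branch is installed with torch + torchvision available;
# "dandelin/vilt-b32-finetuned-vqa" is just an example ViLT checkpoint.
import numpy as np
import torch
from PIL import Image

from transformers import ViltImageProcessor, ViltImageProcessorFast

checkpoint = "dandelin/vilt-b32-finetuned-vqa"
slow = ViltImageProcessor.from_pretrained(checkpoint)
fast = ViltImageProcessorFast.from_pretrained(checkpoint)

# Two images of different sizes so padding and pixel_mask generation are exercised
images = [
    Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)),
    Image.fromarray(np.random.randint(0, 256, (320, 512, 3), dtype=np.uint8)),
]

slow_out = slow(images, do_pad=True, return_tensors="pt")
fast_out = fast(images, do_pad=True, return_tensors="pt")

print(fast_out.pixel_values.shape, fast_out.pixel_mask.shape)
print(torch.allclose(slow_out.pixel_values, fast_out.pixel_values, atol=1e-2))
print(torch.equal(slow_out.pixel_mask, fast_out.pixel_mask))

Since the auto mapping above now lists both classes, loading via AutoImageProcessor with use_fast=True should resolve to ViltImageProcessorFast as well.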