diff --git a/docs/source/en/model_doc/poolformer.md b/docs/source/en/model_doc/poolformer.md index bce183706a8..60573162d68 100644 --- a/docs/source/en/model_doc/poolformer.md +++ b/docs/source/en/model_doc/poolformer.md @@ -73,6 +73,11 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] PoolFormerImageProcessor - preprocess +## PoolFormerImageProcessorFast + +[[autodoc]] PoolFormerImageProcessorFast + - preprocess + ## PoolFormerModel [[autodoc]] PoolFormerModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 01164f3a20a..d0ac43f7471 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -131,7 +131,7 @@ else: ("phi4_multimodal", "Phi4MultimodalImageProcessorFast"), ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), - ("poolformer", ("PoolFormerImageProcessor",)), + ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")), ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor",)), ("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")), ("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")), diff --git a/src/transformers/models/poolformer/__init__.py b/src/transformers/models/poolformer/__init__.py index d0d79274910..2e80bb47c30 100644 --- a/src/transformers/models/poolformer/__init__.py +++ b/src/transformers/models/poolformer/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_poolformer import * from .feature_extraction_poolformer import * from .image_processing_poolformer import * + from .image_processing_poolformer_fast import * from .modeling_poolformer import * else: import sys diff --git a/src/transformers/models/poolformer/image_processing_poolformer_fast.py b/src/transformers/models/poolformer/image_processing_poolformer_fast.py new file mode 100644 index 00000000000..37e81a5f5f9 --- /dev/null +++ b/src/transformers/models/poolformer/image_processing_poolformer_fast.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for PoolFormer.""" + +from typing import Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, +) +from ...image_transforms import ( + ChannelDimension, + get_resize_output_image_size, + get_size_with_aspect_ratio, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size_for_max_height_width, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class PoolFormerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + crop_pct: Optional[float] + + +@add_start_docstrings( + "Constructs a fast PoolFormer image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + crop_pct (`float`, *optional*, defaults to `self.crop_pct`): + Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`. + """, +) +class PoolFormerImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": 224} + default_to_square = False + crop_size = {"height": 224, "width": 224} + crop_pct = 0.9 + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + valid_kwargs = PoolFormerFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[PoolFormerFastImageProcessorKwargs]): + super().__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + crop_pct (`float`, *optional*, defaults to `self.crop_pct`): + Percentage of the image to crop. Only has an effect if `do_resize` is set to `True`. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[PoolFormerFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + crop_pct: Optional[float] = None, + interpolation: "F.InterpolationMode" = None, + antialias: bool = True, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image. + + If crop_pct is unset: + - size is `{"height": h, "width": w}`: the image is resized to `(h, w)`. + - size is `{"shortest_edge": s}`: the shortest edge of the image is resized to s whilst maintaining the + aspect ratio. + + if crop_pct is set: + - size is `{"height": h, "width": w}`: the image is resized to `(int(floor(h/crop_pct)), + int(floor(w/crop_pct)))` + - size is `{"height": c, "width": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)` + whilst maintaining the aspect ratio. + - size is `{"shortest_edge": c}`: the shortest edge of the image is resized to `int(floor(c/crop_pct)` + whilst maintaining the aspect ratio. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. 
+            crop_pct (`float`, *optional*):
+                Percentage of the image that will be cropped from the center. If set, the image is resized to
+                `int(size / crop_pct)` before center cropping.
+            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
+                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
+
+        Returns:
+            `torch.Tensor`: The resized image.
+        """
+        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
+        if crop_pct is not None:
+            if size.shortest_edge:
+                scale_size = int(size.shortest_edge / crop_pct)
+            elif size.height and size.width:
+                if size.height == size.width:
+                    scale_size = int(size.height / crop_pct)
+                else:
+                    scale_size = (int(size.height / crop_pct), int(size.width / crop_pct))
+            else:
+                raise ValueError(f"Invalid size for resize: {size}")
+
+            new_size = get_resize_output_image_size(
+                image,
+                size=scale_size,
+                default_to_square=False,
+                input_data_format=ChannelDimension.FIRST,
+            )
+        else:
+            if size.shortest_edge and size.longest_edge:
+                # Resize the image so that the shortest edge or the longest edge is of the given size
+                # while maintaining the aspect ratio of the original image.
+                new_size = get_size_with_aspect_ratio(
+                    image.size()[-2:],
+                    size.shortest_edge,
+                    size.longest_edge,
+                )
+            elif size.shortest_edge:
+                new_size = get_resize_output_image_size(
+                    image,
+                    size=size.shortest_edge,
+                    default_to_square=False,
+                    input_data_format=ChannelDimension.FIRST,
+                )
+            elif size.max_height and size.max_width:
+                new_size = get_image_size_for_max_height_width(image.size()[-2:], size.max_height, size.max_width)
+            elif size.height and size.width:
+                new_size = (size.height, size.width)
+            else:
+                raise ValueError(
+                    "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got"
+                    f" {size}."
+                )
+        return F.resize(image, new_size, interpolation=interpolation, antialias=antialias)
+
+    def center_crop(
+        self,
+        image: "torch.Tensor",
+        size: SizeDict,
+        **kwargs,
+    ) -> "torch.Tensor":
+        """
+        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+        any edge, the image is padded with 0's and then center cropped.
+
+        Args:
+            image (`torch.Tensor`):
+                Image to center crop.
+            size (`SizeDict`):
+                Size of the output image.
+
+        Returns:
+            `torch.Tensor`: The center cropped image.
+        """
+        if size.height is None or size.width is None:
+            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+        image_height, image_width = image.shape[-2:]
+        crop_height, crop_width = size.height, size.width
+
+        if crop_width > image_width or crop_height > image_height:
+            padding_ltrb = [
+                (crop_width - image_width) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height) // 2 if crop_height > image_height else 0,
+                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+            ]
+            image = F.pad(image, padding_ltrb, fill=0)  # PIL uses fill value 0
+            image_height, image_width = image.shape[-2:]
+        if crop_width == image_width and crop_height == image_height:
+            return image
+
+        crop_top = int((image_height - crop_height) / 2.0)
+        crop_left = int((image_width - crop_width) / 2.0)
+        return F.crop(image, crop_top, crop_left, crop_height, crop_width)
+
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        crop_pct: Optional[float],
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: SizeDict,
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        return_tensors: Optional[Union[str, TensorType]],
+        **kwargs,
+    ) -> BatchFeature:
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_resize:
+                stacked_images = self.resize(
+                    image=stacked_images, size=size, crop_pct=crop_pct, interpolation=interpolation
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_center_crop:
+                stacked_images = self.center_crop(stacked_images, crop_size)
+            # Fused rescale and normalize
+            stacked_images = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["PoolFormerImageProcessorFast"]
diff --git a/tests/models/poolformer/test_image_processing_poolformer.py b/tests/models/poolformer/test_image_processing_poolformer.py
index e5e64d25705..bfaebd92d6b 100644
--- a/tests/models/poolformer/test_image_processing_poolformer.py
+++ b/tests/models/poolformer/test_image_processing_poolformer.py
@@ -15,7 +15,7 @@ import unittest
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -23,6 +23,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
 if is_vision_available():
     from transformers import PoolFormerImageProcessor
 
+    if is_torchvision_available():
+        from transformers import PoolFormerImageProcessorFast
+
 
 class PoolFormerImageProcessingTester:
     def __init__(
@@ -85,6 +88,7 @@ class PoolFormerImageProcessingTester:
 @require_vision
 class PoolFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = PoolFormerImageProcessor if is_vision_available() else None
+    fast_image_processing_class = PoolFormerImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -95,19 +99,29 @@ class PoolFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processing = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processing, "do_resize_and_center_crop"))
-        self.assertTrue(hasattr(image_processing, "size"))
-        self.assertTrue(hasattr(image_processing, "crop_pct"))
-        self.assertTrue(hasattr(image_processing, "do_normalize"))
-        self.assertTrue(hasattr(image_processing, "image_mean"))
-        self.assertTrue(hasattr(image_processing, "image_std"))
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "do_resize_and_center_crop"))
+            self.assertTrue(hasattr(image_processing, "size"))
+            self.assertTrue(hasattr(image_processing, "crop_pct"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
 
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"shortest_edge": 30})
-        self.assertEqual(image_processor.crop_size, {"height": 30, "width": 30})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"shortest_edge": 30})
+            self.assertEqual(image_processor.crop_size, {"height": 30, "width": 30})
 
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
-        self.assertEqual(image_processor.size, {"shortest_edge": 42})
-        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
+            self.assertEqual(image_processor.size, {"shortest_edge": 42})
+            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+
+
+@require_torch
+@require_vision
+class PoolFormerImageProcessingNoCropPctTest(PoolFormerImageProcessingTest):
+    def setUp(self):
+        super().setUp()
+        self.image_processor_tester = PoolFormerImageProcessingTester(self, crop_pct=None)
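
For reviewers, a minimal usage sketch of the new fast processor once this patch is applied. This is illustrative only: the checkpoint name `sail/poolformer_s12` and the size arithmetic in the comments assume the class defaults shown in the diff (`size={"shortest_edge": 224}`, `crop_pct=0.9`, `crop_size={"height": 224, "width": 224}`).

```python
from PIL import Image

from transformers import AutoImageProcessor, PoolFormerImageProcessorFast

# With this patch, "poolformer" maps to both the slow and fast processors in
# IMAGE_PROCESSOR_MAPPING_NAMES, so use_fast=True resolves to the fast class.
processor = AutoImageProcessor.from_pretrained("sail/poolformer_s12", use_fast=True)
assert isinstance(processor, PoolFormerImageProcessorFast)

image = Image.new("RGB", (640, 480))  # stand-in for a real input image
inputs = processor(images=image, return_tensors="pt")

# Default pipeline: because crop_pct is set, the shortest edge (480) is resized
# to int(224 / 0.9) = 248 while keeping the aspect ratio, then a 224x224 center
# crop is taken, followed by the fused rescale-and-normalize step.
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```

Numerical agreement with the slow `PoolFormerImageProcessor` is not shown here; that is what the shared `ImageProcessingTestMixin` equivalence tests exercised above are expected to cover.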