From 23231d2e3597af60f18fba740140da9d06960e55 Mon Sep 17 00:00:00 2001
From: Mikhail Moskovchenko
Date: Fri, 27 Jun 2025 20:50:21 +0400
Subject: [PATCH] Add `reduce_labels` to Mobilevit fast processor

---
 .../image_processing_mobilenet_v2_fast.py     | 39 ++++----
 .../image_processing_mobilevit_fast.py        | 88 ++++++++++++++++---
 .../test_image_processing_mobilenet_v2.py     | 66 ++++----------
 .../test_image_processing_mobilevit.py        | 32 +++----
 4 files changed, 132 insertions(+), 93 deletions(-)

diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py
index 4edb08bb3d5..be01f33c791 100644
--- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py
+++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py
@@ -103,31 +103,38 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
         do_rescale: bool,
         do_center_crop: bool,
         do_normalize: bool,
-        size: SizeDict,
+        size: Optional[SizeDict],
         interpolation: Optional["F.InterpolationMode"],
-        rescale_factor: float,
-        crop_size: SizeDict,
+        rescale_factor: Optional[float],
+        crop_size: Optional[SizeDict],
         image_mean: Optional[Union[float, list[float]]],
         image_std: Optional[Union[float, list[float]]],
+        disable_grouping: bool,
         return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
     ) -> BatchFeature:
+        processed_images = []
+
         if do_reduce_labels:
             images = self.reduce_label(images)
 
-        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        # Group images by shape for more efficient batch processing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
         resized_images_grouped = {}
+
+        # Process each group of images with the same shape
         for shape, stacked_images in grouped_images.items():
             if do_resize:
                 stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
             resized_images_grouped[shape] = stacked_images
+
+        # Reorder images to original sequence
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
 
-        # Group images by size for further processing
-        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        # Group again after resizing (in case resize produced different sizes)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
         processed_images_grouped = {}
+
         for shape, stacked_images in grouped_images.items():
             if do_center_crop:
                 stacked_images = self.center_crop(stacked_images, crop_size)
@@ -138,7 +145,10 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+
+        # Stack all processed images if return_tensors is specified
         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
         return processed_images
 
     def _preprocess_images(
@@ -170,12 +180,7 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
         kwargs["do_normalize"] = False
         kwargs["do_rescale"] = False
-        kwargs["interpolation"] = (
-            pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
-            if PILImageResampling.NEAREST in pil_torch_interpolation_mapping
-            else kwargs.get("interpolation")
-        )
-        kwargs["input_data_format"] = ChannelDimension.FIRST
+        kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
 
         processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
 
         processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
@@ -233,15 +238,15 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
             images=images,
             **kwargs,
         )
-        data = {"pixel_values": images}
 
         if segmentation_maps is not None:
             segmentation_maps = self._preprocess_segmentation_maps(
                 segmentation_maps=segmentation_maps,
                 **kwargs,
             )
-            data["labels"] = segmentation_maps
-        return BatchFeature(data=data)
+            return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps})
+
+        return BatchFeature(data={"pixel_values": images})
 
     # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2
     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py
index 251666c8012..d727e9a30e3 100644
--- a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py
+++ b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py
@@ -14,9 +14,7 @@
 # limitations under the License.
 """Fast Image processor class for MobileViT."""
 
-from typing import Optional
-
-import torch
+from typing import Optional, Union
 
 from ...image_processing_utils import BatchFeature
 from ...image_processing_utils_fast import (
@@ -27,23 +25,46 @@ from ...image_processing_utils_fast import (
 )
 from ...image_utils import (
     ChannelDimension,
+    ImageInput,
     PILImageResampling,
+    SizeDict,
     is_torch_tensor,
     make_list_of_images,
     pil_torch_interpolation_mapping,
     validate_kwargs,
 )
 from ...processing_utils import Unpack
-from ...utils import auto_docstring
+from ...utils import (
+    TensorType,
+    auto_docstring,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+)
+
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
 
 
 class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
     """
     do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
         Whether to flip the color channels from RGB to BGR or vice versa.
+    do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+        Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+        is used for background, and background itself is not included in all classes of a dataset (e.g.
+        ADE20k). The background label will be replaced by 255.
""" do_flip_channel_order: Optional[bool] + do_reduce_labels: Optional[bool] @auto_docstring @@ -58,28 +79,44 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): do_normalize = None do_convert_rgb = None do_flip_channel_order = True + do_reduce_labels = False valid_kwargs = MobileVitFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]): super().__init__(**kwargs) + # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label + def reduce_label(self, labels: list["torch.Tensor"]): + for idx in range(len(labels)): + label = labels[idx] + label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label) + label = label - 1 + label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label) + labels[idx] = label + + return label + def _preprocess( self, - images, + images: list["torch.Tensor"], + do_reduce_labels: bool, do_resize: bool, - size: Optional[dict], - interpolation: Optional[str], + size: Optional[SizeDict], + interpolation: Optional["F.InterpolationMode"], do_rescale: bool, rescale_factor: Optional[float], do_center_crop: bool, - crop_size: Optional[dict], + crop_size: Optional[SizeDict], do_flip_channel_order: bool, disable_grouping: bool, - return_tensors: Optional[str], + return_tensors: Optional[Union[str, TensorType]], **kwargs, - ): + ) -> BatchFeature: processed_images = [] + if do_reduce_labels: + images = self.reduce_label(images) + # Group images by shape for more efficient batch processing grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) resized_images_grouped = {} @@ -119,6 +156,16 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): return processed_images + def _preprocess_images( + self, + images, + **kwargs, + ): + """Preprocesses images.""" + kwargs["do_reduce_labels"] = False + processed_images = self._preprocess(images=images, **kwargs) + return processed_images + def _preprocess_segmentation_maps( self, segmentation_maps, @@ -149,8 +196,8 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): @auto_docstring def preprocess( self, - images, - segmentation_maps=None, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, **kwargs: Unpack[MobileVitFastImageProcessorKwargs], ) -> BatchFeature: r""" @@ -192,7 +239,7 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): kwargs.pop("default_to_square") kwargs.pop("data_format") - images = self._preprocess( + images = self._preprocess_images( images=images, **kwargs, ) @@ -207,6 +254,21 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): return BatchFeature(data={"pixel_values": images}) def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): + """ + Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`MobileNetV2ForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`list[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. 
+ """ logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py index 5d4a2437f3c..7027a0b77a3 100644 --- a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py +++ b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py @@ -15,6 +15,7 @@ import unittest +import requests from datasets import load_dataset from transformers.testing_utils import require_torch, require_vision @@ -89,23 +90,14 @@ class MobileNetV2ImageProcessingTester: def prepare_semantic_single_inputs(): - dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True) - - image = Image.open(dataset[0]["file"]) - map = Image.open(dataset[1]["file"]) - - return image, map + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + example = ds[0] + return example["image"], example["map"] def prepare_semantic_batch_inputs(): - dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True) - - image1 = Image.open(dataset[0]["file"]) - map1 = Image.open(dataset[1]["file"]) - image2 = Image.open(dataset[2]["file"]) - map2 = Image.open(dataset[3]["file"]) - - return [image1, image2], [map1, map2] + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + return list(ds["image"][:2]), list(ds["map"][:2]) @require_torch @@ -275,41 +267,21 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase if self.image_processing_class is None or self.fast_image_processing_class is None: self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") - dummy_image, dummy_map = prepare_semantic_single_inputs() - + # Test with single image + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) image_processor_slow = self.image_processing_class(**self.image_processor_dict) image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) - image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") - image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) - self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1)) - self.assertLessEqual( - torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3 - ) - self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1)) + # Test with single image and segmentation map + image, segmentation_map = prepare_semantic_single_inputs() - def test_slow_fast_equivalence_batched(self): - if not self.test_slow_image_processor or not self.test_fast_image_processor: - self.skipTest(reason="Skipping slow/fast equivalence test") - - if self.image_processing_class is None or self.fast_image_processing_class is None: - self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") - - if hasattr(self.image_processor_tester, "do_center_crop") and 
self.image_processor_tester.do_center_crop: - self.skipTest( - reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" - ) - - dummy_images, dummy_maps = prepare_semantic_batch_inputs() - - image_processor_slow = self.image_processing_class(**self.image_processor_dict) - image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) - - encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") - encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") - - self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) - self.assertLessEqual( - torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 - ) + encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt") + encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3) diff --git a/tests/models/mobilevit/test_image_processing_mobilevit.py b/tests/models/mobilevit/test_image_processing_mobilevit.py index 9cf49fca7c9..a09c2824ca0 100644 --- a/tests/models/mobilevit/test_image_processing_mobilevit.py +++ b/tests/models/mobilevit/test_image_processing_mobilevit.py @@ -248,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): self.assertTrue(encoding["labels"].min().item() >= 0) self.assertTrue(encoding["labels"].max().item() <= 255) + def test_reduce_labels(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + + # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 + image, map = prepare_semantic_single_inputs() + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 150) + + image_processing.do_reduce_labels = True + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + @require_vision @require_torch def test_slow_fast_equivalence(self): @@ -275,19 +291,3 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt") self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3) - - def test_reduce_labels(self): - for image_processing_class in self.image_processor_list: - # Initialize image_processing - image_processing = self.image_processing_class(**self.image_processor_dict) - - # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 - image, map = prepare_semantic_single_inputs() - encoding = image_processing(image, map, return_tensors="pt") - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 150) - - image_processing.do_reduce_labels = True - encoding = image_processing(image, map, 
return_tensors="pt") - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255)
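
Note on the semantics being ported: `do_reduce_labels` remaps ADE20k-style segmentation maps so that the background class 0 becomes the ignore index 255 and every remaining class id is shifted down by one (the 150 labeled classes end up as 0-149). The snippet below is a minimal, standalone sketch of that rule in plain PyTorch for anyone who wants to sanity-check the new `reduce_label` method; the `reduce_label` function here is a re-implementation for illustration only (the method added above walks a list of label tensors and updates it in place), not the transformers API itself.

    import torch

    def reduce_label(label: torch.Tensor) -> torch.Tensor:
        # 0 (background) -> 255, every other class id shifted down by 1,
        # mirroring the torch.where chain in MobileViTImageProcessorFast.reduce_label above.
        label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
        label = label - 1
        label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
        return label

    seg_map = torch.tensor([[0, 1, 150], [2, 0, 149]], dtype=torch.uint8)
    print(reduce_label(seg_map))
    # tensor([[255,   0, 149],
    #         [  1, 255, 148]], dtype=torch.uint8)

This is also why `test_reduce_labels` relaxes its bound from `max() <= 150` to `max() <= 255` once `do_reduce_labels` is enabled: background pixels then carry 255 instead of 0.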