diff --git a/docs/source/en/model_doc/mobilenet_v2.md b/docs/source/en/model_doc/mobilenet_v2.md
index b78a8eb72f6..ffe830ac8d9 100644
--- a/docs/source/en/model_doc/mobilenet_v2.md
+++ b/docs/source/en/model_doc/mobilenet_v2.md
@@ -84,6 +84,11 @@ If you're interested in submitting a resource to be included here, please feel f
 
 [[autodoc]] MobileNetV2ImageProcessor
     - preprocess
+
+## MobileNetV2ImageProcessorFast
+
+[[autodoc]] MobileNetV2ImageProcessorFast
+    - preprocess
     - post_process_semantic_segmentation
 
 ## MobileNetV2Model
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 5e8dae8326e..3389261575c 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -116,7 +116,7 @@ else:
         ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
         ("mllama", ("MllamaImageProcessor",)),
         ("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
-        ("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
+        ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
         ("mobilevit", ("MobileViTImageProcessor",)),
         ("mobilevitv2", ("MobileViTImageProcessor",)),
         ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
diff --git a/src/transformers/models/mobilenet_v2/__init__.py b/src/transformers/models/mobilenet_v2/__init__.py
index c29b5fc245e..0a5dbc3ce4c 100644
--- a/src/transformers/models/mobilenet_v2/__init__.py
+++ b/src/transformers/models/mobilenet_v2/__init__.py
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from .configuration_mobilenet_v2 import *
     from .feature_extraction_mobilenet_v2 import *
     from .image_processing_mobilenet_v2 import *
+    from .image_processing_mobilenet_v2_fast import *
     from .modeling_mobilenet_v2 import *
 else:
     import sys
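With the auto mapping and the `__init__.py` export above, the fast class becomes resolvable by name. A minimal sketch of what that enables, assuming the public `google/mobilenet_v2_1.0_224` checkpoint and a torchvision install:

```python
from transformers import AutoImageProcessor

# use_fast=True asks the auto class for the torchvision-backed processor;
# it requires torchvision to be installed.
processor = AutoImageProcessor.from_pretrained("google/mobilenet_v2_1.0_224", use_fast=True)
print(type(processor).__name__)  # MobileNetV2ImageProcessorFast
```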
+"""Fast Image processor class for MobileNetV2.""" + +from typing import List, Tuple + +from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling +from ...utils import add_start_docstrings, is_torch_available, is_torch_tensor + + +if is_torch_available(): + import torch + + +@add_start_docstrings( + "Constructs a fast MobileNetV2 image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class MobileNetV2ImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"shortest_edge": 256} + default_to_square = False + crop_size = {"height": 224, "width": 224} + do_resize = True + do_center_crop = True + do_rescale = True + do_normalize = True + do_convert_rgb = None + + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`MobileNetV2ForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + +__all__ = ["MobileNetV2ImageProcessorFast"] diff --git a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py index ed4331de168..526fe04738b 100644 --- a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py +++ b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py @@ -16,7 +16,7 @@ import unittest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im if is_vision_available(): from transformers import MobileNetV2ImageProcessor + if is_torchvision_available(): + from transformers 
diff --git a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
index ed4331de168..526fe04738b 100644
--- a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
+++ b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py
@@ -16,7 +16,7 @@ import unittest
 
 from transformers.testing_utils import require_torch, require_vision
-from transformers.utils import is_vision_available
+from transformers.utils import is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
@@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
 
 if is_vision_available():
     from transformers import MobileNetV2ImageProcessor
 
+    if is_torchvision_available():
+        from transformers import MobileNetV2ImageProcessorFast
+
 
 class MobileNetV2ImageProcessingTester:
     def __init__(
@@ -79,6 +82,7 @@ class MobileNetV2ImageProcessingTester:
 @require_vision
 class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = MobileNetV2ImageProcessor if is_vision_available() else None
+    fast_image_processing_class = MobileNetV2ImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -89,17 +93,19 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         return self.image_processor_tester.prepare_image_processor_dict()
 
     def test_image_processor_properties(self):
-        image_processor = self.image_processing_class(**self.image_processor_dict)
-        self.assertTrue(hasattr(image_processor, "do_resize"))
-        self.assertTrue(hasattr(image_processor, "size"))
-        self.assertTrue(hasattr(image_processor, "do_center_crop"))
-        self.assertTrue(hasattr(image_processor, "crop_size"))
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processor, "do_resize"))
+            self.assertTrue(hasattr(image_processor, "size"))
+            self.assertTrue(hasattr(image_processor, "do_center_crop"))
+            self.assertTrue(hasattr(image_processor, "crop_size"))
 
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"shortest_edge": 20})
-        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"shortest_edge": 20})
+            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
 
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
-        self.assertEqual(image_processor.size, {"shortest_edge": 42})
-        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
+            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
+            self.assertEqual(image_processor.size, {"shortest_edge": 42})
+            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
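The rewritten tests loop over `self.image_processor_list` so the slow and fast classes share the property and `from_dict` checks, and the common `ImageProcessingTestMixin` also compares slow and fast outputs for equivalence. A standalone sketch of that comparison; the random image and the loose tolerances are this example's assumptions, since PIL and torchvision bilinear resizes differ slightly:

```python
import numpy as np
import torch

from transformers import MobileNetV2ImageProcessor, MobileNetV2ImageProcessorFast

# A random RGB image stands in for the test fixtures
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

pixels_slow = MobileNetV2ImageProcessor()(image, return_tensors="pt").pixel_values
pixels_fast = MobileNetV2ImageProcessorFast()(image, return_tensors="pt").pixel_values

# Default config: shortest edge resized to 256, then a 224x224 center crop
assert pixels_slow.shape == pixels_fast.shape == (1, 3, 224, 224)
# Interpolation backends differ slightly, hence the loose tolerance
torch.testing.assert_close(pixels_slow, pixels_fast, rtol=1e-3, atol=1e-1)
```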