Add Fast Mobilenet-V2 Processor (#37113)

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Parteek 2025-04-14 20:38:47 +05:30 committed by GitHub
parent 4774a39d05
commit a53a63c9c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 114 additions and 13 deletions

View File

@ -84,6 +84,11 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] MobileNetV2ImageProcessor
- preprocess
## MobileNetV2ImageProcessorFast
[[autodoc]] MobileNetV2ImageProcessorFast
- preprocess
- post_process_semantic_segmentation
## MobileNetV2Model

View File

@ -116,7 +116,7 @@ else:
("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("mllama", ("MllamaImageProcessor",)),
("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
("mobilevit", ("MobileViTImageProcessor",)),
("mobilevitv2", ("MobileViTImageProcessor",)),
("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_mobilenet_v2 import *
from .feature_extraction_mobilenet_v2 import *
from .image_processing_mobilenet_v2 import *
from .image_processing_mobilenet_v2_fast import *
from .modeling_mobilenet_v2 import *
else:
import sys

View File

@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for MobileNetV2."""
from typing import List, Tuple
from ...image_processing_utils_fast import BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, BaseImageProcessorFast
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling
from ...utils import add_start_docstrings, is_torch_available, is_torch_tensor
if is_torch_available():
import torch
@add_start_docstrings(
"Constructs a fast MobileNetV2 image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"shortest_edge": 256}
default_to_square = False
crop_size = {"height": 224, "width": 224}
do_resize = True
do_center_crop = True
do_rescale = True
do_normalize = True
do_convert_rgb = None
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`MobileNetV2ForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
__all__ = ["MobileNetV2ImageProcessorFast"]

View File

@ -16,7 +16,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@ -24,6 +24,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import MobileNetV2ImageProcessor
if is_torchvision_available():
from transformers import MobileNetV2ImageProcessorFast
class MobileNetV2ImageProcessingTester:
def __init__(
@ -79,6 +82,7 @@ class MobileNetV2ImageProcessingTester:
@require_vision
class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = MobileNetV2ImageProcessor if is_vision_available() else None
fast_image_processing_class = MobileNetV2ImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@ -89,17 +93,19 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processor = self.image_processing_class(**self.image_processor_dict)
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_center_crop"))
self.assertTrue(hasattr(image_processor, "crop_size"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})