Add reduce_labels to MobileViT fast processor

Mikhail Moskovchenko 2025-06-27 20:50:21 +04:00
parent d7ac282524
commit 23231d2e35
4 changed files with 132 additions and 93 deletions
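For context, `do_reduce_labels` remaps ADE20k-style segmentation maps in which id 0 is the background: 0 becomes the ignore index 255 and every other class id shifts down by 1. A minimal sketch of that remapping on a toy map (the tensor values are made up for illustration; the processor applies the same logic to each label tensor):

```python
import torch


def reduce_label(label: torch.Tensor) -> torch.Tensor:
    # Background (0) -> ignore index 255, every other class id -> id - 1.
    label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
    label = label - 1
    label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
    return label


seg_map = torch.tensor([[0, 1, 150], [3, 0, 2]], dtype=torch.uint8)
print(reduce_label(seg_map))
# tensor([[255,   0, 149],
#         [  2, 255,   1]], dtype=torch.uint8)
```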

src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py

@@ -103,31 +103,38 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
         do_rescale: bool,
         do_center_crop: bool,
         do_normalize: bool,
-        size: SizeDict,
+        size: Optional[SizeDict],
         interpolation: Optional["F.InterpolationMode"],
-        rescale_factor: float,
-        crop_size: SizeDict,
+        rescale_factor: Optional[float],
+        crop_size: Optional[SizeDict],
         image_mean: Optional[Union[float, list[float]]],
         image_std: Optional[Union[float, list[float]]],
+        disable_grouping: bool,
         return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
     ) -> BatchFeature:
+        processed_images = []
         if do_reduce_labels:
             images = self.reduce_label(images)
-        # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        # Group images by shape for more efficient batch processing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
         resized_images_grouped = {}
+        # Process each group of images with the same shape
         for shape, stacked_images in grouped_images.items():
             if do_resize:
                 stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
             resized_images_grouped[shape] = stacked_images
+        # Reorder images to original sequence
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
-        # Group images by size for further processing
-        # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        # Group again after resizing (in case resize produced different sizes)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
         processed_images_grouped = {}
         for shape, stacked_images in grouped_images.items():
             if do_center_crop:
                 stacked_images = self.center_crop(stacked_images, crop_size)
@@ -138,7 +145,10 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        # Stack all processed images if return_tensors is specified
         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return processed_images

     def _preprocess_images(
@@ -170,12 +180,7 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
         kwargs["do_normalize"] = False
         kwargs["do_rescale"] = False
-        kwargs["interpolation"] = (
-            pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
-            if PILImageResampling.NEAREST in pil_torch_interpolation_mapping
-            else kwargs.get("interpolation")
-        )
-        kwargs["input_data_format"] = ChannelDimension.FIRST
+        kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
         processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
         processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
@@ -233,15 +238,15 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
             images=images,
             **kwargs,
         )
-        data = {"pixel_values": images}
         if segmentation_maps is not None:
             segmentation_maps = self._preprocess_segmentation_maps(
                 segmentation_maps=segmentation_maps,
                 **kwargs,
             )
-            data["labels"] = segmentation_maps
-        return BatchFeature(data=data)
+            return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps})
+        return BatchFeature(data={"pixel_values": images})

     # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2
     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
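With the restructured return path above, `labels` only appears in the returned `BatchFeature` when segmentation maps are passed. A usage sketch; the checkpoint name is an assumption, and any MobileNetV2 checkpoint with an image processor config should behave the same way:

```python
from datasets import load_dataset
from transformers import AutoImageProcessor

ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image, seg_map = ds[0]["image"], ds[0]["map"]

# Hypothetical checkpoint choice; use_fast=True selects the fast MobileNetV2 image processor.
processor = AutoImageProcessor.from_pretrained("google/deeplabv3_mobilenet_v2_1.0_513", use_fast=True)

print(list(processor(image, return_tensors="pt").keys()))
# ['pixel_values']
print(list(processor(image, segmentation_maps=seg_map, return_tensors="pt").keys()))
# ['pixel_values', 'labels']
```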

src/transformers/models/mobilevit/image_processing_mobilevit_fast.py

@@ -14,9 +14,7 @@
 # limitations under the License.
 """Fast Image processor class for MobileViT."""
-from typing import Optional
-
-import torch
+from typing import Optional, Union

 from ...image_processing_utils import BatchFeature
 from ...image_processing_utils_fast import (
@@ -27,23 +25,46 @@ from ...image_processing_utils_fast import (
 )
 from ...image_utils import (
     ChannelDimension,
+    ImageInput,
     PILImageResampling,
+    SizeDict,
     is_torch_tensor,
     make_list_of_images,
     pil_torch_interpolation_mapping,
     validate_kwargs,
 )
 from ...processing_utils import Unpack
-from ...utils import auto_docstring
+from ...utils import (
+    TensorType,
+    auto_docstring,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+)
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F


 class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
     """
     do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
         Whether to flip the color channels from RGB to BGR or vice versa.
+    do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+        Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+        is used for background, and background itself is not included in all classes of a dataset (e.g.
+        ADE20k). The background label will be replaced by 255.
     """

     do_flip_channel_order: Optional[bool]
+    do_reduce_labels: Optional[bool]


 @auto_docstring
@@ -58,28 +79,44 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
     do_normalize = None
     do_convert_rgb = None
     do_flip_channel_order = True
+    do_reduce_labels = False
     valid_kwargs = MobileVitFastImageProcessorKwargs

     def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
         super().__init__(**kwargs)

+    # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
+    def reduce_label(self, labels: list["torch.Tensor"]):
+        for idx in range(len(labels)):
+            label = labels[idx]
+            label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
+            label = label - 1
+            label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
+            labels[idx] = label
+        return labels
+
     def _preprocess(
         self,
-        images,
+        images: list["torch.Tensor"],
+        do_reduce_labels: bool,
         do_resize: bool,
-        size: Optional[dict],
-        interpolation: Optional[str],
+        size: Optional[SizeDict],
+        interpolation: Optional["F.InterpolationMode"],
         do_rescale: bool,
         rescale_factor: Optional[float],
         do_center_crop: bool,
-        crop_size: Optional[dict],
+        crop_size: Optional[SizeDict],
         do_flip_channel_order: bool,
         disable_grouping: bool,
-        return_tensors: Optional[str],
+        return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
-    ):
+    ) -> BatchFeature:
         processed_images = []
+        if do_reduce_labels:
+            images = self.reduce_label(images)
         # Group images by shape for more efficient batch processing
         grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
         resized_images_grouped = {}
@@ -119,6 +156,16 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
         return processed_images

+    def _preprocess_images(
+        self,
+        images,
+        **kwargs,
+    ):
+        """Preprocesses images."""
+        kwargs["do_reduce_labels"] = False
+        processed_images = self._preprocess(images=images, **kwargs)
+        return processed_images
+
     def _preprocess_segmentation_maps(
         self,
         segmentation_maps,
@@ -149,8 +196,8 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
     @auto_docstring
     def preprocess(
         self,
-        images,
-        segmentation_maps=None,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
         **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
     ) -> BatchFeature:
         r"""
@@ -192,7 +239,7 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
         kwargs.pop("default_to_square")
         kwargs.pop("data_format")
-        images = self._preprocess(
+        images = self._preprocess_images(
             images=images,
             **kwargs,
         )
@@ -207,6 +254,21 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
         return BatchFeature(data={"pixel_values": images})

     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
+        """
+        Converts the output of [`MobileViTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+        Args:
+            outputs ([`MobileViTForSemanticSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
+                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
+                predictions will not be resized.
+
+        Returns:
+            semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
+                segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+                specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
+        """
         logits = outputs.logits
         # Resize logits and compute semantic segmentation maps
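Since the commit also fills in the `post_process_semantic_segmentation` docstring, here is a short usage sketch of that method with the fast processor; the checkpoint name is an assumption, and any MobileViT semantic-segmentation checkpoint should work the same way:

```python
import torch
from datasets import load_dataset
from transformers import MobileViTForSemanticSegmentation, MobileViTImageProcessorFast

ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = ds[0]["image"]

# Hypothetical checkpoint choice.
processor = MobileViTImageProcessorFast.from_pretrained("apple/deeplabv3-mobilevit-small")
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects (height, width); PIL's .size is (width, height), hence the reversal.
predicted_maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])
print(predicted_maps[0].shape)  # (height, width), one class id per pixel
```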

tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py

@@ -15,6 +15,7 @@
 import unittest

+import requests
 from datasets import load_dataset

 from transformers.testing_utils import require_torch, require_vision
@@ -89,23 +90,14 @@ class MobileNetV2ImageProcessingTester:
 def prepare_semantic_single_inputs():
-    dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
-    image = Image.open(dataset[0]["file"])
-    map = Image.open(dataset[1]["file"])
-    return image, map
+    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+    example = ds[0]
+    return example["image"], example["map"]


 def prepare_semantic_batch_inputs():
-    dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
-    image1 = Image.open(dataset[0]["file"])
-    map1 = Image.open(dataset[1]["file"])
-    image2 = Image.open(dataset[2]["file"])
-    map2 = Image.open(dataset[3]["file"])
-    return [image1, image2], [map1, map2]
+    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+    return list(ds["image"][:2]), list(ds["map"][:2])


 @require_torch
@@ -275,41 +267,21 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
         if self.image_processing_class is None or self.fast_image_processing_class is None:
             self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

-        dummy_image, dummy_map = prepare_semantic_single_inputs()
+        # Test with single image
+        dummy_image = Image.open(
+            requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
+        )
         image_processor_slow = self.image_processing_class(**self.image_processor_dict)
         image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)

-        image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
-        image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt")
-        self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3
-        )
-        self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1))
-
-    def test_slow_fast_equivalence_batched(self):
-        if not self.test_slow_image_processor or not self.test_fast_image_processor:
-            self.skipTest(reason="Skipping slow/fast equivalence test")
-        if self.image_processing_class is None or self.fast_image_processing_class is None:
-            self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
-        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
-            self.skipTest(
-                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
-            )
-        dummy_images, dummy_maps = prepare_semantic_batch_inputs()
-        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
-        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
-        encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
-        encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt")
-        self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
-        self.assertLessEqual(
-            torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
-        )
+        encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
+        encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+
+        # Test with single image and segmentation map
+        image, segmentation_map = prepare_semantic_single_inputs()
+
+        encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt")
+        encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
+        self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
+        torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)
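The rewritten test leans on the mixin helper `_assert_slow_fast_tensors_equivalence` instead of spelling the checks out inline. As a rough sketch of what those inline assertions amounted to (the helper's actual internals and tolerances may differ):

```python
import torch


def assert_slow_fast_tensors_equivalence(slow: torch.Tensor, fast: torch.Tensor) -> None:
    # Elementwise closeness with a loose absolute tolerance ...
    torch.testing.assert_close(slow, fast, atol=1e-1, rtol=1e-3)
    # ... plus a much tighter bound on the mean absolute difference.
    assert torch.mean(torch.abs(slow - fast)).item() <= 1e-3
```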

tests/models/mobilevit/test_image_processing_mobilevit.py

@@ -248,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         self.assertTrue(encoding["labels"].min().item() >= 0)
         self.assertTrue(encoding["labels"].max().item() <= 255)

+    def test_reduce_labels(self):
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
+            image, map = prepare_semantic_single_inputs()
+            encoding = image_processing(image, map, return_tensors="pt")
+            self.assertTrue(encoding["labels"].min().item() >= 0)
+            self.assertTrue(encoding["labels"].max().item() <= 150)
+
+            image_processing.do_reduce_labels = True
+            encoding = image_processing(image, map, return_tensors="pt")
+            self.assertTrue(encoding["labels"].min().item() >= 0)
+            self.assertTrue(encoding["labels"].max().item() <= 255)
+
     @require_vision
     @require_torch
     def test_slow_fast_equivalence(self):
@@ -275,19 +291,3 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
         self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
         torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)
-
-    def test_reduce_labels(self):
-        for image_processing_class in self.image_processor_list:
-            # Initialize image_processing
-            image_processing = self.image_processing_class(**self.image_processor_dict)
-            # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
-            image, map = prepare_semantic_single_inputs()
-            encoding = image_processing(image, map, return_tensors="pt")
-            self.assertTrue(encoding["labels"].min().item() >= 0)
-            self.assertTrue(encoding["labels"].max().item() <= 150)
-            image_processing.do_reduce_labels = True
-            encoding = image_processing(image, map, return_tensors="pt")
-            self.assertTrue(encoding["labels"].min().item() >= 0)
-            self.assertTrue(encoding["labels"].max().item() <= 255)
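End to end, the behavior this relocated test pins down looks like the following with the fast processor (checkpoint name is an assumption; the dataset is the same fixture the tests use):

```python
from datasets import load_dataset
from transformers import MobileViTImageProcessorFast

ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image, seg_map = ds[0]["image"], ds[0]["map"]

# Hypothetical checkpoint choice.
processor = MobileViTImageProcessorFast.from_pretrained("apple/deeplabv3-mobilevit-small")

encoding = processor(image, seg_map, return_tensors="pt")
print(int(encoding["labels"].max()))  # <= 150: raw ADE20k ids, background kept as 0

processor.do_reduce_labels = True
encoding = processor(image, seg_map, return_tensors="pt")
print(int(encoding["labels"].max()))  # <= 255: background remapped to the ignore index
```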