This commit is contained in:
Mikhail Moskovchenko 2025-07-02 23:02:18 +02:00 committed by GitHub
commit a9c5d3962e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 723 additions and 52 deletions

View File

@ -114,6 +114,7 @@ print(f"The predicted class label is: {predicted_class_label}")
[[autodoc]] MobileNetV2ImageProcessor [[autodoc]] MobileNetV2ImageProcessor
- preprocess - preprocess
- post_process_semantic_segmentation
## MobileNetV2ImageProcessorFast ## MobileNetV2ImageProcessorFast

View File

@ -88,6 +88,11 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
`preprocess` method.
""" """
model_input_names = ["pixel_values"] model_input_names = ["pixel_values"]
@ -104,6 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
do_normalize: bool = True, do_normalize: bool = True,
image_mean: Optional[Union[float, list[float]]] = None, image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None,
do_reduce_labels: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
@ -121,6 +127,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
self.do_normalize = do_normalize self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels
# Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize
def resize( def resize(
@ -172,10 +179,151 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
**kwargs, **kwargs,
) )
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
def reduce_label(self, label: ImageInput) -> np.ndarray:
label = to_numpy_array(label)
# Avoid using underflow conversion
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
return label
def __call__(self, images, segmentation_maps=None, **kwargs):
"""
Preprocesses a batch of images and optionally segmentation maps.
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
passed in as positional arguments.
"""
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
def _preprocess(
self,
image: ImageInput,
do_reduce_labels: bool,
do_resize: bool,
do_rescale: bool,
do_center_crop: bool,
do_normalize: bool,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None,
rescale_factor: Optional[float] = None,
crop_size: Optional[dict[str, int]] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_reduce_labels:
image = self.reduce_label(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[dict[str, int]] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if do_rescale and is_scaled_image(image):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
image = self._preprocess(
image=image,
do_reduce_labels=False,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_reduce_labels: Optional[bool] = None,
do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
input_data_format = ChannelDimension.FIRST
else:
added_channel_dim = False
if input_data_format is None:
input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
segmentation_map = self._preprocess(
image=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
size=size,
resample=PILImageResampling.NEAREST,
do_rescale=False,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_normalize=False,
image_mean=None,
image_std=None,
input_data_format=input_data_format,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
segmentation_map = segmentation_map.astype(np.int64)
return segmentation_map
@filter_out_non_signature_kwargs() @filter_out_non_signature_kwargs()
def preprocess( def preprocess(
self, self,
images: ImageInput, images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_resize: Optional[bool] = None, do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None, resample: PILImageResampling = None,
@ -186,6 +334,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
do_normalize: Optional[bool] = None, do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None, image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
@ -197,6 +346,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
images (`ImageInput`): images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`. passing in images with pixel values between 0 and 1, set `do_rescale=False`.
segmentation_maps (`ImageInput`, *optional*):
Segmentation map to preprocess.
do_resize (`bool`, *optional*, defaults to `self.do_resize`): do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image. Whether to resize the image.
size (`dict[str, int]`, *optional*, defaults to `self.size`): size (`dict[str, int]`, *optional*, defaults to `self.size`):
@ -219,6 +370,10 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
Image mean to use if `do_normalize` is set to `True`. Image mean to use if `do_normalize` is set to `True`.
image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use if `do_normalize` is set to `True`. Image standard deviation to use if `do_normalize` is set to `True`.
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
return_tensors (`str` or `TensorType`, *optional*): return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of: The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`. - Unset: Return a list of `np.ndarray`.
@ -241,6 +396,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
@ -253,11 +409,21 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
images = make_list_of_images(images) images = make_list_of_images(images)
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)
if not valid_images(images): if not valid_images(images):
raise ValueError( raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray." "torch.Tensor, tf.Tensor or jax.ndarray."
) )
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments( validate_preprocess_arguments(
do_rescale=do_rescale, do_rescale=do_rescale,
rescale_factor=rescale_factor, rescale_factor=rescale_factor,
@ -270,42 +436,43 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
size=size, size=size,
resample=resample, resample=resample,
) )
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_rescale and is_scaled_image(images[0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
all_images = []
for image in images:
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
all_images.append(image)
images = [ images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) self._preprocess_image(
for image in all_images image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
input_data_format=input_data_format,
)
for img in images
] ]
data = {"pixel_values": images} data = {"pixel_values": images}
if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_mask(
segmentation_map=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize,
size=size,
do_center_crop=do_center_crop,
crop_size=crop_size,
input_data_format=input_data_format,
)
for segmentation_map in segmentation_maps
]
data["labels"] = segmentation_maps
return BatchFeature(data=data, tensor_type=return_tensors) return BatchFeature(data=data, tensor_type=return_tensors)
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileNetV2 # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileNetV2

View File

@ -14,16 +14,57 @@
# limitations under the License. # limitations under the License.
"""Fast Image processor class for MobileNetV2.""" """Fast Image processor class for MobileNetV2."""
from typing import Optional from typing import Optional, Union
from ...image_processing_utils_fast import BaseImageProcessorFast from ...image_processing_utils import BatchFeature
from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling from ...image_processing_utils_fast import (
from ...utils import auto_docstring, is_torch_available, is_torch_tensor BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
is_torch_tensor,
make_list_of_images,
pil_torch_interpolation_mapping,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available(): if is_torch_available():
import torch import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class MobileNetV2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
"""
do_reduce_labels: Optional[bool]
@auto_docstring @auto_docstring
class MobileNetV2ImageProcessorFast(BaseImageProcessorFast): class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
@ -37,8 +78,177 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast):
do_center_crop = True do_center_crop = True
do_rescale = True do_rescale = True
do_normalize = True do_normalize = True
do_convert_rgb = None do_reduce_labels = False
valid_kwargs = MobileNetV2FastImageProcessorKwargs
def __init__(self, **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs]):
super().__init__(**kwargs)
# Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
def reduce_label(self, labels: list["torch.Tensor"]):
for idx in range(len(labels)):
label = labels[idx]
label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
label = label - 1
label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
labels[idx] = label
return label
def _preprocess(
self,
images: list["torch.Tensor"],
do_reduce_labels: bool,
do_resize: bool,
do_rescale: bool,
do_center_crop: bool,
do_normalize: bool,
size: Optional[SizeDict],
interpolation: Optional["F.InterpolationMode"],
rescale_factor: Optional[float],
crop_size: Optional[SizeDict],
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: bool,
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
processed_images = []
if do_reduce_labels:
images = self.reduce_label(images)
# Group images by shape for more efficient batch processing
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
# Process each group of images with the same shape
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
# Reorder images to original sequence
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group again after resizing (in case resize produced different sizes)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
# Stack all processed images if return_tensors is specified
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return processed_images
def _preprocess_images(
self,
images,
**kwargs,
):
"""Preprocesses images."""
kwargs["do_reduce_labels"] = False
processed_images = self._preprocess(images=images, **kwargs)
return processed_images
def _preprocess_segmentation_maps(
self,
segmentation_maps,
**kwargs,
):
"""Preprocesses segmentation maps."""
processed_segmentation_maps = []
for segmentation_map in segmentation_maps:
segmentation_map = self._process_image(
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
)
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
processed_segmentation_maps.append(segmentation_map)
kwargs["do_normalize"] = False
kwargs["do_rescale"] = False
kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST]
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
return processed_segmentation_maps
@auto_docstring
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[MobileNetV2FastImageProcessorKwargs],
) -> BatchFeature:
r"""
segmentation_maps (`ImageInput`, *optional*):
The segmentation maps to preprocess.
"""
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Prepare segmentation maps
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")
images = self._preprocess_images(
images=images,
**kwargs,
)
if segmentation_maps is not None:
segmentation_maps = self._preprocess_segmentation_maps(
segmentation_maps=segmentation_maps,
**kwargs,
)
return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps})
return BatchFeature(data={"pixel_values": images})
# Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
""" """
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

View File

@ -83,6 +83,11 @@ class MobileViTImageProcessor(BaseImageProcessor):
do_flip_channel_order (`bool`, *optional*, defaults to `True`): do_flip_channel_order (`bool`, *optional*, defaults to `True`):
Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order` Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
parameter in the `preprocess` method. parameter in the `preprocess` method.
do_reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
`preprocess` method.
""" """
model_input_names = ["pixel_values"] model_input_names = ["pixel_values"]
@ -97,6 +102,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
do_center_crop: bool = True, do_center_crop: bool = True,
crop_size: Optional[dict[str, int]] = None, crop_size: Optional[dict[str, int]] = None,
do_flip_channel_order: bool = True, do_flip_channel_order: bool = True,
do_reduce_labels: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
@ -113,6 +119,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
self.do_center_crop = do_center_crop self.do_center_crop = do_center_crop
self.crop_size = crop_size self.crop_size = crop_size
self.do_flip_channel_order = do_flip_channel_order self.do_flip_channel_order = do_flip_channel_order
self.do_reduce_labels = do_reduce_labels
# Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR
def resize( def resize(
@ -183,6 +190,15 @@ class MobileViTImageProcessor(BaseImageProcessor):
""" """
return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format) return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format)
# Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
def reduce_label(self, label: ImageInput) -> np.ndarray:
label = to_numpy_array(label)
# Avoid using underflow conversion
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
return label
def __call__(self, images, segmentation_maps=None, **kwargs): def __call__(self, images, segmentation_maps=None, **kwargs):
""" """
Preprocesses a batch of images and optionally segmentation maps. Preprocesses a batch of images and optionally segmentation maps.
@ -195,6 +211,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
def _preprocess( def _preprocess(
self, self,
image: ImageInput, image: ImageInput,
do_reduce_labels: bool,
do_resize: bool, do_resize: bool,
do_rescale: bool, do_rescale: bool,
do_center_crop: bool, do_center_crop: bool,
@ -205,6 +222,9 @@ class MobileViTImageProcessor(BaseImageProcessor):
crop_size: Optional[dict[str, int]] = None, crop_size: Optional[dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
): ):
if do_reduce_labels:
image = self.reduce_label(image)
if do_resize: if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
@ -246,6 +266,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
image = self._preprocess( image = self._preprocess(
image=image, image=image,
do_reduce_labels=False,
do_resize=do_resize, do_resize=do_resize,
size=size, size=size,
resample=resample, resample=resample,
@ -264,6 +285,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
def _preprocess_mask( def _preprocess_mask(
self, self,
segmentation_map: ImageInput, segmentation_map: ImageInput,
do_reduce_labels: Optional[bool] = None,
do_resize: Optional[bool] = None, do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
do_center_crop: Optional[bool] = None, do_center_crop: Optional[bool] = None,
@ -284,6 +306,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
segmentation_map = self._preprocess( segmentation_map = self._preprocess(
image=segmentation_map, image=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize, do_resize=do_resize,
size=size, size=size,
resample=PILImageResampling.NEAREST, resample=PILImageResampling.NEAREST,
@ -312,6 +335,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
do_center_crop: Optional[bool] = None, do_center_crop: Optional[bool] = None,
crop_size: Optional[dict[str, int]] = None, crop_size: Optional[dict[str, int]] = None,
do_flip_channel_order: Optional[bool] = None, do_flip_channel_order: Optional[bool] = None,
do_reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None, return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST, data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None,
@ -342,6 +366,10 @@ class MobileViTImageProcessor(BaseImageProcessor):
Size of the center crop if `do_center_crop` is set to `True`. Size of the center crop if `do_center_crop` is set to `True`.
do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
Whether to flip the channel order of the image. Whether to flip the channel order of the image.
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
return_tensors (`str` or `TensorType`, *optional*): return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of: The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`. - Unset: Return a list of `np.ndarray`.
@ -374,6 +402,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size") crop_size = get_size_dict(crop_size, param_name="crop_size")
do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels
images = make_list_of_images(images) images = make_list_of_images(images)
if segmentation_maps is not None: if segmentation_maps is not None:
@ -426,6 +456,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
segmentation_maps = [ segmentation_maps = [
self._preprocess_mask( self._preprocess_mask(
segmentation_map=segmentation_map, segmentation_map=segmentation_map,
do_reduce_labels=do_reduce_labels,
do_resize=do_resize, do_resize=do_resize,
size=size, size=size,
do_center_crop=do_center_crop, do_center_crop=do_center_crop,

View File

@ -14,9 +14,7 @@
# limitations under the License. # limitations under the License.
"""Fast Image processor class for MobileViT.""" """Fast Image processor class for MobileViT."""
from typing import Optional from typing import Optional, Union
import torch
from ...image_processing_utils import BatchFeature from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import ( from ...image_processing_utils_fast import (
@ -27,23 +25,46 @@ from ...image_processing_utils_fast import (
) )
from ...image_utils import ( from ...image_utils import (
ChannelDimension, ChannelDimension,
ImageInput,
PILImageResampling, PILImageResampling,
SizeDict,
is_torch_tensor, is_torch_tensor,
make_list_of_images, make_list_of_images,
pil_torch_interpolation_mapping, pil_torch_interpolation_mapping,
validate_kwargs, validate_kwargs,
) )
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
""" """
do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`):
Whether to flip the color channels from RGB to BGR or vice versa. Whether to flip the color channels from RGB to BGR or vice versa.
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
""" """
do_flip_channel_order: Optional[bool] do_flip_channel_order: Optional[bool]
do_reduce_labels: Optional[bool]
@auto_docstring @auto_docstring
@ -58,28 +79,44 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
do_normalize = None do_normalize = None
do_convert_rgb = None do_convert_rgb = None
do_flip_channel_order = True do_flip_channel_order = True
do_reduce_labels = False
valid_kwargs = MobileVitFastImageProcessorKwargs valid_kwargs = MobileVitFastImageProcessorKwargs
def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]): def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]):
super().__init__(**kwargs) super().__init__(**kwargs)
# Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label
def reduce_label(self, labels: list["torch.Tensor"]):
for idx in range(len(labels)):
label = labels[idx]
label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
label = label - 1
label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
labels[idx] = label
return label
def _preprocess( def _preprocess(
self, self,
images, images: list["torch.Tensor"],
do_reduce_labels: bool,
do_resize: bool, do_resize: bool,
size: Optional[dict], size: Optional[SizeDict],
interpolation: Optional[str], interpolation: Optional["F.InterpolationMode"],
do_rescale: bool, do_rescale: bool,
rescale_factor: Optional[float], rescale_factor: Optional[float],
do_center_crop: bool, do_center_crop: bool,
crop_size: Optional[dict], crop_size: Optional[SizeDict],
do_flip_channel_order: bool, do_flip_channel_order: bool,
disable_grouping: bool, disable_grouping: bool,
return_tensors: Optional[str], return_tensors: Optional[Union[str, TensorType]],
**kwargs, **kwargs,
): ) -> BatchFeature:
processed_images = [] processed_images = []
if do_reduce_labels:
images = self.reduce_label(images)
# Group images by shape for more efficient batch processing # Group images by shape for more efficient batch processing
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {} resized_images_grouped = {}
@ -119,6 +156,16 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
return processed_images return processed_images
def _preprocess_images(
self,
images,
**kwargs,
):
"""Preprocesses images."""
kwargs["do_reduce_labels"] = False
processed_images = self._preprocess(images=images, **kwargs)
return processed_images
def _preprocess_segmentation_maps( def _preprocess_segmentation_maps(
self, self,
segmentation_maps, segmentation_maps,
@ -149,8 +196,8 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
@auto_docstring @auto_docstring
def preprocess( def preprocess(
self, self,
images, images: ImageInput,
segmentation_maps=None, segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[MobileVitFastImageProcessorKwargs], **kwargs: Unpack[MobileVitFastImageProcessorKwargs],
) -> BatchFeature: ) -> BatchFeature:
r""" r"""
@ -192,7 +239,7 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
kwargs.pop("default_to_square") kwargs.pop("default_to_square")
kwargs.pop("data_format") kwargs.pop("data_format")
images = self._preprocess( images = self._preprocess_images(
images=images, images=images,
**kwargs, **kwargs,
) )
@ -207,6 +254,21 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast):
return BatchFeature(data={"pixel_values": images}) return BatchFeature(data={"pixel_values": images})
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
"""
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`MobileNetV2ForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits logits = outputs.logits
# Resize logits and compute semantic segmentation maps # Resize logits and compute semantic segmentation maps

View File

@ -15,13 +15,21 @@
import unittest import unittest
import requests
from datasets import load_dataset
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available, is_vision_available from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available(): if is_vision_available():
from PIL import Image
from transformers import MobileNetV2ImageProcessor from transformers import MobileNetV2ImageProcessor
if is_torchvision_available(): if is_torchvision_available():
@ -41,6 +49,7 @@ class MobileNetV2ImageProcessingTester:
size=None, size=None,
do_center_crop=True, do_center_crop=True,
crop_size=None, crop_size=None,
do_reduce_labels=False,
): ):
size = size if size is not None else {"shortest_edge": 20} size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
@ -54,6 +63,7 @@ class MobileNetV2ImageProcessingTester:
self.size = size self.size = size
self.do_center_crop = do_center_crop self.do_center_crop = do_center_crop
self.crop_size = crop_size self.crop_size = crop_size
self.do_reduce_labels = do_reduce_labels
def prepare_image_processor_dict(self): def prepare_image_processor_dict(self):
return { return {
@ -61,6 +71,7 @@ class MobileNetV2ImageProcessingTester:
"size": self.size, "size": self.size,
"do_center_crop": self.do_center_crop, "do_center_crop": self.do_center_crop,
"crop_size": self.crop_size, "crop_size": self.crop_size,
"do_reduce_labels": self.do_reduce_labels,
} }
def expected_output_image_shape(self, images): def expected_output_image_shape(self, images):
@ -78,6 +89,17 @@ class MobileNetV2ImageProcessingTester:
) )
def prepare_semantic_single_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
example = ds[0]
return example["image"], example["map"]
def prepare_semantic_batch_inputs():
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
return list(ds["image"][:2]), list(ds["map"][:2])
@require_torch @require_torch
@require_vision @require_vision
class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@ -99,13 +121,167 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase
self.assertTrue(hasattr(image_processor, "size")) self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_center_crop")) self.assertTrue(hasattr(image_processor, "do_center_crop"))
self.assertTrue(hasattr(image_processor, "crop_size")) self.assertTrue(hasattr(image_processor, "crop_size"))
self.assertTrue(hasattr(image_processor, "do_reduce_labels"))
def test_image_processor_from_dict_with_kwargs(self): def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict) image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20}) self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
self.assertEqual(image_processor.do_reduce_labels, False)
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) image_processor = image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
)
self.assertEqual(image_processor.size, {"shortest_edge": 42}) self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.do_reduce_labels, True)
def test_call_segmentation_maps(self):
# Initialize image_processing
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
maps = []
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
maps.append(torch.zeros(image.shape[-2:]).long())
# Test not batched input
encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched
encoding = image_processing(image_inputs, maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test not batched input (PIL images)
image, segmentation_map = prepare_semantic_single_inputs()
encoding = image_processing(image, segmentation_map, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
1,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
1,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
# Test batched input (PIL images)
images, segmentation_maps = prepare_semantic_batch_inputs()
encoding = image_processing(images, segmentation_maps, return_tensors="pt")
self.assertEqual(
encoding["pixel_values"].shape,
(
2,
self.image_processor_tester.num_channels,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(
encoding["labels"].shape,
(
2,
self.image_processor_tester.crop_size["height"],
self.image_processor_tester.crop_size["width"],
),
)
self.assertEqual(encoding["labels"].dtype, torch.long)
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
def test_reduce_labels(self):
# Initialize image_processing
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
image, map = prepare_semantic_single_inputs()
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 150)
image_processing.do_reduce_labels = True
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
# Test with single image
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(dummy_image, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_image, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
# Test with single image and segmentation map
image, segmentation_map = prepare_semantic_single_inputs()
encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt")
encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt")
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)
torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3)

View File

@ -50,6 +50,7 @@ class MobileViTImageProcessingTester:
do_center_crop=True, do_center_crop=True,
crop_size=None, crop_size=None,
do_flip_channel_order=True, do_flip_channel_order=True,
do_reduce_labels=False,
): ):
size = size if size is not None else {"shortest_edge": 20} size = size if size is not None else {"shortest_edge": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
@ -64,6 +65,7 @@ class MobileViTImageProcessingTester:
self.do_center_crop = do_center_crop self.do_center_crop = do_center_crop
self.crop_size = crop_size self.crop_size = crop_size
self.do_flip_channel_order = do_flip_channel_order self.do_flip_channel_order = do_flip_channel_order
self.do_reduce_labels = do_reduce_labels
def prepare_image_processor_dict(self): def prepare_image_processor_dict(self):
return { return {
@ -72,6 +74,7 @@ class MobileViTImageProcessingTester:
"do_center_crop": self.do_center_crop, "do_center_crop": self.do_center_crop,
"crop_size": self.crop_size, "crop_size": self.crop_size,
"do_flip_channel_order": self.do_flip_channel_order, "do_flip_channel_order": self.do_flip_channel_order,
"do_reduce_labels": self.do_reduce_labels,
} }
def expected_output_image_shape(self, images): def expected_output_image_shape(self, images):
@ -122,16 +125,21 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(image_processing, "do_center_crop")) self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "center_crop")) self.assertTrue(hasattr(image_processing, "center_crop"))
self.assertTrue(hasattr(image_processing, "do_flip_channel_order")) self.assertTrue(hasattr(image_processing, "do_flip_channel_order"))
self.assertTrue(hasattr(image_processing, "do_reduce_labels"))
def test_image_processor_from_dict_with_kwargs(self): def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
image_processor = self.image_processing_class.from_dict(self.image_processor_dict) image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 20}) self.assertEqual(image_processor.size, {"shortest_edge": 20})
self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
self.assertEqual(image_processor.do_reduce_labels, False)
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True
)
self.assertEqual(image_processor.size, {"shortest_edge": 42}) self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.do_reduce_labels, True)
def test_call_segmentation_maps(self): def test_call_segmentation_maps(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list:
@ -240,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(encoding["labels"].min().item() >= 0) self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255) self.assertTrue(encoding["labels"].max().item() <= 255)
def test_reduce_labels(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150
image, map = prepare_semantic_single_inputs()
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 150)
image_processing.do_reduce_labels = True
encoding = image_processing(image, map, return_tensors="pt")
self.assertTrue(encoding["labels"].min().item() >= 0)
self.assertTrue(encoding["labels"].max().item() <= 255)
@require_vision @require_vision
@require_torch @require_torch
def test_slow_fast_equivalence(self): def test_slow_fast_equivalence(self):