diff --git a/docs/source/en/model_doc/mobilenet_v2.md b/docs/source/en/model_doc/mobilenet_v2.md index 1ba55a8e87a..5ddc4f0ea3f 100644 --- a/docs/source/en/model_doc/mobilenet_v2.md +++ b/docs/source/en/model_doc/mobilenet_v2.md @@ -114,6 +114,7 @@ print(f"The predicted class label is: {predicted_class_label}") [[autodoc]] MobileNetV2ImageProcessor - preprocess + - post_process_semantic_segmentation ## MobileNetV2ImageProcessorFast diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py index c8887ab836b..41d765ea314 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py @@ -88,6 +88,11 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. """ model_input_names = ["pixel_values"] @@ -104,6 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): do_normalize: bool = True, image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, + do_reduce_labels: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -121,6 +127,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.do_reduce_labels = do_reduce_labels # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize def resize( @@ -172,10 +179,151 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): **kwargs, ) + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label + def reduce_label(self, label: ImageInput) -> np.ndarray: + label = to_numpy_array(label) + # Avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + return label + + def __call__(self, images, segmentation_maps=None, **kwargs): + """ + Preprocesses a batch of images and optionally segmentation maps. + + Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be + passed in as positional arguments. + """ + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_reduce_labels: bool, + do_resize: bool, + do_rescale: bool, + do_center_crop: bool, + do_normalize: bool, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + rescale_factor: Optional[float] = None, + crop_size: Optional[dict[str, int]] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_reduce_labels: + image = self.reduce_label(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + image = self._preprocess( + image=image, + do_reduce_labels=False, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + input_data_format=input_data_format, + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_reduce_labels: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[dict[str, int]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single mask.""" + segmentation_map = to_numpy_array(segmentation_map) + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + input_data_format = ChannelDimension.FIRST + else: + added_channel_dim = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + + segmentation_map = self._preprocess( + image=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + size=size, + resample=PILImageResampling.NEAREST, + do_rescale=False, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_normalize=False, + image_mean=None, + image_std=None, + input_data_format=input_data_format, + ) + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + segmentation_map = segmentation_map.astype(np.int64) + return segmentation_map + @filter_out_non_signature_kwargs() def preprocess( self, images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, do_resize: Optional[bool] = None, size: Optional[dict[str, int]] = None, resample: PILImageResampling = None, @@ -186,6 +334,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): do_normalize: Optional[bool] = None, image_mean: Optional[Union[float, list[float]]] = None, image_std: Optional[Union[float, list[float]]] = None, + do_reduce_labels: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -197,6 +346,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): images (`ImageInput`): Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. do_resize (`bool`, *optional*, defaults to `self.do_resize`): Whether to resize the image. size (`dict[str, int]`, *optional*, defaults to `self.size`): @@ -219,6 +370,10 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): Image mean to use if `do_normalize` is set to `True`. image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): Image standard deviation to use if `do_normalize` is set to `True`. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -241,6 +396,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size = get_size_dict(size, default_to_square=False) + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size @@ -253,11 +409,21 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): images = make_list_of_images(images) + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) + + if segmentation_maps is not None and not valid_images(segmentation_maps): + raise ValueError( + "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( do_rescale=do_rescale, rescale_factor=rescale_factor, @@ -270,42 +436,43 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): size=size, resample=resample, ) - # All transformations expect numpy arrays. - images = [to_numpy_array(image) for image in images] - if do_rescale and is_scaled_image(images[0]): - logger.warning_once( - "It looks like you are trying to rescale already rescaled images. If the input" - " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." - ) - - if input_data_format is None: - # We assume that all images have the same channel dimension format. - input_data_format = infer_channel_dimension_format(images[0]) - - all_images = [] - for image in images: - if do_resize: - image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) - - if do_center_crop: - image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) - - if do_rescale: - image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - - if do_normalize: - image = self.normalize( - image=image, mean=image_mean, std=image_std, input_data_format=input_data_format - ) - - all_images.append(image) images = [ - to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - for image in all_images + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images ] data = {"pixel_values": images} + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_mask( + segmentation_map=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + size=size, + do_center_crop=do_center_crop, + crop_size=crop_size, + input_data_format=input_data_format, + ) + for segmentation_map in segmentation_maps + ] + data["labels"] = segmentation_maps + return BatchFeature(data=data, tensor_type=return_tensors) # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->MobileNetV2 diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py index 59eb43917d4..be01f33c791 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py @@ -14,16 +14,57 @@ # limitations under the License. """Fast Image processor class for MobileNetV2.""" -from typing import Optional +from typing import Optional, Union -from ...image_processing_utils_fast import BaseImageProcessorFast -from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, PILImageResampling -from ...utils import auto_docstring, is_torch_available, is_torch_tensor +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + is_torch_tensor, + make_list_of_images, + pil_torch_interpolation_mapping, + validate_kwargs, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) if is_torch_available(): import torch +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class MobileNetV2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. + """ + + do_reduce_labels: Optional[bool] + @auto_docstring class MobileNetV2ImageProcessorFast(BaseImageProcessorFast): @@ -37,8 +78,177 @@ class MobileNetV2ImageProcessorFast(BaseImageProcessorFast): do_center_crop = True do_rescale = True do_normalize = True - do_convert_rgb = None + do_reduce_labels = False + valid_kwargs = MobileNetV2FastImageProcessorKwargs + def __init__(self, **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs]): + super().__init__(**kwargs) + + # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label + def reduce_label(self, labels: list["torch.Tensor"]): + for idx in range(len(labels)): + label = labels[idx] + label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label) + label = label - 1 + label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label) + labels[idx] = label + + return label + + def _preprocess( + self, + images: list["torch.Tensor"], + do_reduce_labels: bool, + do_resize: bool, + do_rescale: bool, + do_center_crop: bool, + do_normalize: bool, + size: Optional[SizeDict], + interpolation: Optional["F.InterpolationMode"], + rescale_factor: Optional[float], + crop_size: Optional[SizeDict], + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + disable_grouping: bool, + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + processed_images = [] + + if do_reduce_labels: + images = self.reduce_label(images) + + # Group images by shape for more efficient batch processing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + + # Process each group of images with the same shape + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + + # Reorder images to original sequence + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group again after resizing (in case resize produced different sizes) + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + + # Stack all processed images if return_tensors is specified + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return processed_images + + def _preprocess_images( + self, + images, + **kwargs, + ): + """Preprocesses images.""" + kwargs["do_reduce_labels"] = False + processed_images = self._preprocess(images=images, **kwargs) + return processed_images + + def _preprocess_segmentation_maps( + self, + segmentation_maps, + **kwargs, + ): + """Preprocesses segmentation maps.""" + processed_segmentation_maps = [] + for segmentation_map in segmentation_maps: + segmentation_map = self._process_image( + segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST + ) + + if segmentation_map.ndim == 2: + segmentation_map = segmentation_map[None, ...] + + processed_segmentation_maps.append(segmentation_map) + + kwargs["do_normalize"] = False + kwargs["do_rescale"] = False + kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] + processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) + + processed_segmentation_maps = processed_segmentation_maps.squeeze(1) + + processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) + return processed_segmentation_maps + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + # Prepare input images + images = self._prepare_input_images( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + + # Prepare segmentation maps + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) + + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + # Pop kwargs that are not needed in _preprocess + kwargs.pop("default_to_square") + kwargs.pop("data_format") + + images = self._preprocess_images( + images=images, + **kwargs, + ) + + if segmentation_maps is not None: + segmentation_maps = self._preprocess_segmentation_maps( + segmentation_maps=segmentation_maps, + **kwargs, + ) + return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps}) + + return BatchFeature(data={"pixel_values": images}) + + # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2 def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit.py b/src/transformers/models/mobilevit/image_processing_mobilevit.py index b9eb353a654..86707613764 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit.py @@ -83,6 +83,11 @@ class MobileViTImageProcessor(BaseImageProcessor): do_flip_channel_order (`bool`, *optional*, defaults to `True`): Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order` parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. """ model_input_names = ["pixel_values"] @@ -97,6 +102,7 @@ class MobileViTImageProcessor(BaseImageProcessor): do_center_crop: bool = True, crop_size: Optional[dict[str, int]] = None, do_flip_channel_order: bool = True, + do_reduce_labels: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -113,6 +119,7 @@ class MobileViTImageProcessor(BaseImageProcessor): self.do_center_crop = do_center_crop self.crop_size = crop_size self.do_flip_channel_order = do_flip_channel_order + self.do_reduce_labels = do_reduce_labels # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR def resize( @@ -183,6 +190,15 @@ class MobileViTImageProcessor(BaseImageProcessor): """ return flip_channel_order(image, data_format=data_format, input_data_format=input_data_format) + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label + def reduce_label(self, label: ImageInput) -> np.ndarray: + label = to_numpy_array(label) + # Avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + return label + def __call__(self, images, segmentation_maps=None, **kwargs): """ Preprocesses a batch of images and optionally segmentation maps. @@ -195,6 +211,7 @@ class MobileViTImageProcessor(BaseImageProcessor): def _preprocess( self, image: ImageInput, + do_reduce_labels: bool, do_resize: bool, do_rescale: bool, do_center_crop: bool, @@ -205,6 +222,9 @@ class MobileViTImageProcessor(BaseImageProcessor): crop_size: Optional[dict[str, int]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ): + if do_reduce_labels: + image = self.reduce_label(image) + if do_resize: image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) @@ -246,6 +266,7 @@ class MobileViTImageProcessor(BaseImageProcessor): image = self._preprocess( image=image, + do_reduce_labels=False, do_resize=do_resize, size=size, resample=resample, @@ -264,6 +285,7 @@ class MobileViTImageProcessor(BaseImageProcessor): def _preprocess_mask( self, segmentation_map: ImageInput, + do_reduce_labels: Optional[bool] = None, do_resize: Optional[bool] = None, size: Optional[dict[str, int]] = None, do_center_crop: Optional[bool] = None, @@ -284,6 +306,7 @@ class MobileViTImageProcessor(BaseImageProcessor): segmentation_map = self._preprocess( image=segmentation_map, + do_reduce_labels=do_reduce_labels, do_resize=do_resize, size=size, resample=PILImageResampling.NEAREST, @@ -312,6 +335,7 @@ class MobileViTImageProcessor(BaseImageProcessor): do_center_crop: Optional[bool] = None, crop_size: Optional[dict[str, int]] = None, do_flip_channel_order: Optional[bool] = None, + do_reduce_labels: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -342,6 +366,10 @@ class MobileViTImageProcessor(BaseImageProcessor): Size of the center crop if `do_center_crop` is set to `True`. do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): Whether to flip the channel order of the image. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. return_tensors (`str` or `TensorType`, *optional*): The type of tensors to return. Can be one of: - Unset: Return a list of `np.ndarray`. @@ -374,6 +402,8 @@ class MobileViTImageProcessor(BaseImageProcessor): crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size") + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + images = make_list_of_images(images) if segmentation_maps is not None: @@ -426,6 +456,7 @@ class MobileViTImageProcessor(BaseImageProcessor): segmentation_maps = [ self._preprocess_mask( segmentation_map=segmentation_map, + do_reduce_labels=do_reduce_labels, do_resize=do_resize, size=size, do_center_crop=do_center_crop, diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py index 251666c8012..d727e9a30e3 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py @@ -14,9 +14,7 @@ # limitations under the License. """Fast Image processor class for MobileViT.""" -from typing import Optional - -import torch +from typing import Optional, Union from ...image_processing_utils import BatchFeature from ...image_processing_utils_fast import ( @@ -27,23 +25,46 @@ from ...image_processing_utils_fast import ( ) from ...image_utils import ( ChannelDimension, + ImageInput, PILImageResampling, + SizeDict, is_torch_tensor, make_list_of_images, pil_torch_interpolation_mapping, validate_kwargs, ) from ...processing_utils import Unpack -from ...utils import auto_docstring +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F class MobileVitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): """ do_flip_channel_order (`bool`, *optional*, defaults to `self.do_flip_channel_order`): Whether to flip the color channels from RGB to BGR or vice versa. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. """ do_flip_channel_order: Optional[bool] + do_reduce_labels: Optional[bool] @auto_docstring @@ -58,28 +79,44 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): do_normalize = None do_convert_rgb = None do_flip_channel_order = True + do_reduce_labels = False valid_kwargs = MobileVitFastImageProcessorKwargs def __init__(self, **kwargs: Unpack[MobileVitFastImageProcessorKwargs]): super().__init__(**kwargs) + # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.reduce_label + def reduce_label(self, labels: list["torch.Tensor"]): + for idx in range(len(labels)): + label = labels[idx] + label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label) + label = label - 1 + label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label) + labels[idx] = label + + return label + def _preprocess( self, - images, + images: list["torch.Tensor"], + do_reduce_labels: bool, do_resize: bool, - size: Optional[dict], - interpolation: Optional[str], + size: Optional[SizeDict], + interpolation: Optional["F.InterpolationMode"], do_rescale: bool, rescale_factor: Optional[float], do_center_crop: bool, - crop_size: Optional[dict], + crop_size: Optional[SizeDict], do_flip_channel_order: bool, disable_grouping: bool, - return_tensors: Optional[str], + return_tensors: Optional[Union[str, TensorType]], **kwargs, - ): + ) -> BatchFeature: processed_images = [] + if do_reduce_labels: + images = self.reduce_label(images) + # Group images by shape for more efficient batch processing grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) resized_images_grouped = {} @@ -119,6 +156,16 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): return processed_images + def _preprocess_images( + self, + images, + **kwargs, + ): + """Preprocesses images.""" + kwargs["do_reduce_labels"] = False + processed_images = self._preprocess(images=images, **kwargs) + return processed_images + def _preprocess_segmentation_maps( self, segmentation_maps, @@ -149,8 +196,8 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): @auto_docstring def preprocess( self, - images, - segmentation_maps=None, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, **kwargs: Unpack[MobileVitFastImageProcessorKwargs], ) -> BatchFeature: r""" @@ -192,7 +239,7 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): kwargs.pop("default_to_square") kwargs.pop("data_format") - images = self._preprocess( + images = self._preprocess_images( images=images, **kwargs, ) @@ -207,6 +254,21 @@ class MobileViTImageProcessorFast(BaseImageProcessorFast): return BatchFeature(data={"pixel_values": images}) def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): + """ + Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`MobileNetV2ForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`list[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ logits = outputs.logits # Resize logits and compute semantic segmentation maps diff --git a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py index 526fe04738b..7027a0b77a3 100644 --- a/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py +++ b/tests/models/mobilenet_v2/test_image_processing_mobilenet_v2.py @@ -15,13 +15,21 @@ import unittest +import requests +from datasets import load_dataset + from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torchvision_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs +if is_torch_available(): + import torch + if is_vision_available(): + from PIL import Image + from transformers import MobileNetV2ImageProcessor if is_torchvision_available(): @@ -41,6 +49,7 @@ class MobileNetV2ImageProcessingTester: size=None, do_center_crop=True, crop_size=None, + do_reduce_labels=False, ): size = size if size is not None else {"shortest_edge": 20} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} @@ -54,6 +63,7 @@ class MobileNetV2ImageProcessingTester: self.size = size self.do_center_crop = do_center_crop self.crop_size = crop_size + self.do_reduce_labels = do_reduce_labels def prepare_image_processor_dict(self): return { @@ -61,6 +71,7 @@ class MobileNetV2ImageProcessingTester: "size": self.size, "do_center_crop": self.do_center_crop, "crop_size": self.crop_size, + "do_reduce_labels": self.do_reduce_labels, } def expected_output_image_shape(self, images): @@ -78,6 +89,17 @@ class MobileNetV2ImageProcessingTester: ) +def prepare_semantic_single_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + example = ds[0] + return example["image"], example["map"] + + +def prepare_semantic_batch_inputs(): + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + return list(ds["image"][:2]), list(ds["map"][:2]) + + @require_torch @require_vision class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): @@ -99,13 +121,167 @@ class MobileNetV2ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase self.assertTrue(hasattr(image_processor, "size")) self.assertTrue(hasattr(image_processor, "do_center_crop")) self.assertTrue(hasattr(image_processor, "crop_size")) + self.assertTrue(hasattr(image_processor, "do_reduce_labels")) def test_image_processor_from_dict_with_kwargs(self): for image_processing_class in self.image_processor_list: image_processor = image_processing_class.from_dict(self.image_processor_dict) self.assertEqual(image_processor.size, {"shortest_edge": 20}) self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + self.assertEqual(image_processor.do_reduce_labels, False) - image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + image_processor = image_processing_class.from_dict( + self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True + ) self.assertEqual(image_processor.size, {"shortest_edge": 42}) self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + self.assertEqual(image_processor.do_reduce_labels, True) + + def test_call_segmentation_maps(self): + # Initialize image_processing + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + maps = [] + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + maps.append(torch.zeros(image.shape[-2:]).long()) + + # Test not batched input + encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched + encoding = image_processing(image_inputs, maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test not batched input (PIL images) + image, segmentation_map = prepare_semantic_single_inputs() + + encoding = image_processing(image, segmentation_map, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + # Test batched input (PIL images) + images, segmentation_maps = prepare_semantic_batch_inputs() + + encoding = image_processing(images, segmentation_maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 2, + self.image_processor_tester.num_channels, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 2, + self.image_processor_tester.crop_size["height"], + self.image_processor_tester.crop_size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + def test_reduce_labels(self): + # Initialize image_processing + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + + # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 + image, map = prepare_semantic_single_inputs() + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 150) + + image_processing.do_reduce_labels = True + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + # Test with single image + dummy_image = Image.open( + requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw + ) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + + # Test with single image and segmentation map + image, segmentation_map = prepare_semantic_single_inputs() + + encoding_slow = image_processor_slow(image, segmentation_map, return_tensors="pt") + encoding_fast = image_processor_fast(image, segmentation_map, return_tensors="pt") + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + torch.testing.assert_close(encoding_slow.labels, encoding_fast.labels, atol=1e-1, rtol=1e-3) diff --git a/tests/models/mobilevit/test_image_processing_mobilevit.py b/tests/models/mobilevit/test_image_processing_mobilevit.py index df5caa6b7fb..a09c2824ca0 100644 --- a/tests/models/mobilevit/test_image_processing_mobilevit.py +++ b/tests/models/mobilevit/test_image_processing_mobilevit.py @@ -50,6 +50,7 @@ class MobileViTImageProcessingTester: do_center_crop=True, crop_size=None, do_flip_channel_order=True, + do_reduce_labels=False, ): size = size if size is not None else {"shortest_edge": 20} crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} @@ -64,6 +65,7 @@ class MobileViTImageProcessingTester: self.do_center_crop = do_center_crop self.crop_size = crop_size self.do_flip_channel_order = do_flip_channel_order + self.do_reduce_labels = do_reduce_labels def prepare_image_processor_dict(self): return { @@ -72,6 +74,7 @@ class MobileViTImageProcessingTester: "do_center_crop": self.do_center_crop, "crop_size": self.crop_size, "do_flip_channel_order": self.do_flip_channel_order, + "do_reduce_labels": self.do_reduce_labels, } def expected_output_image_shape(self, images): @@ -122,16 +125,21 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): self.assertTrue(hasattr(image_processing, "do_center_crop")) self.assertTrue(hasattr(image_processing, "center_crop")) self.assertTrue(hasattr(image_processing, "do_flip_channel_order")) + self.assertTrue(hasattr(image_processing, "do_reduce_labels")) def test_image_processor_from_dict_with_kwargs(self): for image_processing_class in self.image_processor_list: image_processor = self.image_processing_class.from_dict(self.image_processor_dict) self.assertEqual(image_processor.size, {"shortest_edge": 20}) self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + self.assertEqual(image_processor.do_reduce_labels, False) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, size=42, crop_size=84, do_reduce_labels=True + ) self.assertEqual(image_processor.size, {"shortest_edge": 42}) self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + self.assertEqual(image_processor.do_reduce_labels, True) def test_call_segmentation_maps(self): for image_processing_class in self.image_processor_list: @@ -240,6 +248,22 @@ class MobileViTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): self.assertTrue(encoding["labels"].min().item() >= 0) self.assertTrue(encoding["labels"].max().item() <= 255) + def test_reduce_labels(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + + # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 + image, map = prepare_semantic_single_inputs() + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 150) + + image_processing.do_reduce_labels = True + encoding = image_processing(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + @require_vision @require_torch def test_slow_fast_equivalence(self):