From b922b22ec2e458978dbd89038ad4b47885b34195 Mon Sep 17 00:00:00 2001 From: Sam Rae Date: Wed, 18 Jun 2025 18:33:29 +0100 Subject: [PATCH] 36978 | Fast image processor for DPT model (#37481) * chore: ran codegen script * test: test_image_processor_properties * test: test_image_processor_from_dict_with_kwargs * test: wip - test_padding * test: test_padding * test: test_keep_aspect_ratio * wip * test * test: wip * test: wip * test: test_call_segmentation_maps, wip * chore: tidy up * test: test_call_segmentation_maps * fix: test_save_load_fast_slow * test: reduce labels * chore: make fixup * chore: rm comment * chore: tidy * chore remove comment * refactor: no need to infer channel dimesnion * refactor: encapsulate logic for preparing segmentation maps * refactor: improve readability of segmentation_map preparation * improvement: batched version of pad_image * chore: fixup * docs * chore: make quality * chore: remove unecessary comment * fix: add SemanticSegmentationMixin * feat: add post_process_depth_estimation to fast dpt image processor * chore: fix formatting * remove max_height, max_width * fix: better way of processin segmentation maps - copied from Beit Fast processor * chore: formatting + remove TODO * chore: fixup styles * chore: remove unecessary line break * chore: core review suggestion to remove autodocstring * fix: add do_reduce_labels logic + refactor - refactor preprocess logic to make it consistent with other processors - add missing reduce labels logic * refactor: remove deprecated mixin * chore: fixup * use modular for dpt + final nit changes * fix style --------- Co-authored-by: Samuel Rae Co-authored-by: yonigozlan --- docs/source/en/model_doc/dpt.md | 6 + .../models/auto/image_processing_auto.py | 4 +- .../models/beit/image_processing_beit_fast.py | 5 - src/transformers/models/dpt/__init__.py | 1 + .../models/dpt/image_processing_dpt_fast.py | 474 ++++++++++++++++++ src/transformers/models/dpt/modular_dpt.py | 313 ++++++++++++ tests/models/dpt/test_image_processing_dpt.py | 357 +++++++------ 7 files changed, 1010 insertions(+), 150 deletions(-) create mode 100644 src/transformers/models/dpt/image_processing_dpt_fast.py create mode 100644 src/transformers/models/dpt/modular_dpt.py diff --git a/docs/source/en/model_doc/dpt.md b/docs/source/en/model_doc/dpt.md index 16992079738..a763e2af62f 100644 --- a/docs/source/en/model_doc/dpt.md +++ b/docs/source/en/model_doc/dpt.md @@ -78,7 +78,13 @@ If you're interested in submitting a resource to be included here, please feel f [[autodoc]] DPTImageProcessor - preprocess + +## DPTImageProcessorFast + +[[autodoc]] DPTImageProcessorFast + - preprocess - post_process_semantic_segmentation + - post_process_depth_estimation ## DPTModel diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index a3feebced91..69087ec68f9 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -74,14 +74,14 @@ else: ("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")), ("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")), ("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")), - ("depth_anything", ("DPTImageProcessor",)), + ("depth_anything", ("DPTImageProcessor", "DPTImageProcessorFast")), ("depth_pro", ("DepthProImageProcessor", "DepthProImageProcessorFast")), ("deta", ("DetaImageProcessor",)), ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")), ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")), ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")), ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")), - ("dpt", ("DPTImageProcessor",)), + ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")), ("efficientformer", ("EfficientFormerImageProcessor",)), ("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")), ("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")), diff --git a/src/transformers/models/beit/image_processing_beit_fast.py b/src/transformers/models/beit/image_processing_beit_fast.py index 6a4077008da..97eff4f5adf 100644 --- a/src/transformers/models/beit/image_processing_beit_fast.py +++ b/src/transformers/models/beit/image_processing_beit_fast.py @@ -174,11 +174,6 @@ class BeitImageProcessorFast(BaseImageProcessorFast): processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) return processed_segmentation_maps - def __call__(self, images, segmentation_maps=None, **kwargs): - # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both - # be passed in as positional arguments. - return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) - @auto_docstring def preprocess( self, diff --git a/src/transformers/models/dpt/__init__.py b/src/transformers/models/dpt/__init__.py index 086750423db..ce0070f270f 100644 --- a/src/transformers/models/dpt/__init__.py +++ b/src/transformers/models/dpt/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_dpt import * from .feature_extraction_dpt import * from .image_processing_dpt import * + from .image_processing_dpt_fast import * from .modeling_dpt import * else: import sys diff --git a/src/transformers/models/dpt/image_processing_dpt_fast.py b/src/transformers/models/dpt/image_processing_dpt_fast.py new file mode 100644 index 00000000000..c9db9c51852 --- /dev/null +++ b/src/transformers/models/dpt/image_processing_dpt_fast.py @@ -0,0 +1,474 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dpt/modular_dpt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dpt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable +from typing import TYPE_CHECKING, Optional, Union + +from transformers.image_processing_base import BatchFeature +from transformers.image_transforms import group_images_by_shape, reorder_images + +from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + is_torch_tensor, + make_list_of_images, + pil_torch_interpolation_mapping, + validate_kwargs, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + requires_backends, +) + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput + +if is_torch_available(): + import torch + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + ensure_multiple_of (`int`, *optional*, defaults to 1): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden + by `ensure_multiple_of` in `preprocess`. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in + combination with DPT. + size_divisor (`int`, *optional*): + If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the + DINOv2 paper, which uses the model in combination with DPT. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can + be overidden by `keep_aspect_ratio` in `preprocess`. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. + """ + + ensure_multiple_of: Optional[int] + size_divisor: Optional[int] + do_pad: Optional[bool] + keep_aspect_ratio: Optional[bool] + do_reduce_labels: Optional[bool] + + +def get_resize_output_image_size( + input_image: "torch.Tensor", + output_size: Union[int, Iterable[int]], + keep_aspect_ratio: bool, + multiple: int, +) -> SizeDict: + def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): + x = round(val / multiple) * multiple + + if max_val is not None and x > max_val: + x = math.floor(val / multiple) * multiple + + if x < min_val: + x = math.ceil(val / multiple) * multiple + + return x + + input_height, input_width = input_image.shape[-2:] + output_height, output_width = output_size + + # determine new height and width + scale_height = output_height / input_height + scale_width = output_width / input_width + + if keep_aspect_ratio: + # scale as little as possible + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + + new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple) + new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple) + + return SizeDict(height=new_height, width=new_width) + + +@auto_docstring +class DPTImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 384, "width": 384} + default_to_square = True + crop_size = None + do_resize = True + do_center_crop = None + do_rescale = True + do_normalize = True + do_reduce_labels = None + + valid_kwargs = DPTFastImageProcessorKwargs + do_pad = False + rescale_factor = 1 / 255 + ensure_multiple_of = 1 + keep_aspect_ratio = False + + def __init__(self, **kwargs: Unpack[DPTFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def reduce_label(self, labels: list["torch.Tensor"]): + for idx in range(len(labels)): + label = labels[idx] + label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label) + label = label - 1 + label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label) + labels[idx] = label + + return label + + def _preprocess( + self, + images: list["torch.Tensor"], + do_reduce_labels: bool, + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + keep_aspect_ratio: bool, + ensure_multiple_of: Optional[int], + do_pad: bool, + size_divisor: Optional[int], + **kwargs, + ) -> BatchFeature: + if do_reduce_labels: + images = self.reduce_label(images) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, + size=size, + interpolation=interpolation, + ensure_multiple_of=ensure_multiple_of, + keep_aspect_ratio=keep_aspect_ratio, + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + if do_pad: + stacked_images = self.pad_image(stacked_images, size_divisor) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return processed_images + + def _preprocess_images( + self, + images, + **kwargs, + ): + """Preprocesses images.""" + kwargs["do_reduce_labels"] = False + processed_images = self._preprocess(images=images, **kwargs) + return processed_images + + def _preprocess_segmentation_maps( + self, + segmentation_maps, + **kwargs, + ): + """Preprocesses segmentation maps.""" + processed_segmentation_maps = [] + for segmentation_map in segmentation_maps: + segmentation_map = self._process_image( + segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST + ) + + if segmentation_map.ndim == 2: + segmentation_map = segmentation_map[None, ...] + + processed_segmentation_maps.append(segmentation_map) + + kwargs["do_normalize"] = False + kwargs["do_rescale"] = False + kwargs["input_data_format"] = ChannelDimension.FIRST + processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) + + processed_segmentation_maps = processed_segmentation_maps.squeeze(1) + + processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) + return processed_segmentation_maps + + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[DPTFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self.valid_kwargs.__annotations__: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Extract parameters that are only used for preparing the input images + do_convert_rgb = kwargs.pop("do_convert_rgb") + input_data_format = kwargs.pop("input_data_format") + device = kwargs.pop("device") + # Prepare input images + images = self._prepare_input_images( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + + # Prepare segmentation maps + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) + + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + # torch resize uses interpolation instead of resample + resample = kwargs.pop("resample") + kwargs["interpolation"] = ( + pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample + ) + + # Pop kwargs that are not needed in _preprocess + kwargs.pop("default_to_square") + kwargs.pop("data_format") + + images = self._preprocess_images( + images=images, + **kwargs, + ) + data = {"pixel_values": images} + + if segmentation_maps is not None: + segmentation_maps = self._preprocess_segmentation_maps( + segmentation_maps=segmentation_maps, + **kwargs, + ) + data["labels"] = segmentation_maps + + return BatchFeature(data=data) + + def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): + """ + Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`DPTForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`list[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + antialias: bool = True, + ensure_multiple_of: Optional[int] = 1, + keep_aspect_ratio: bool = False, + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. + antialias (`bool`, *optional*, defaults to `True`): + Whether to use antialiasing when resizing the image + ensure_multiple_of (`int`, *optional*): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, and `do_resize` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. + + Returns: + `torch.Tensor`: The resized image. + """ + if not size.height or not size.width: + raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}") + + output_size = get_resize_output_image_size( + image, + output_size=(size.height, size.width), + keep_aspect_ratio=keep_aspect_ratio, + multiple=ensure_multiple_of, + ) + return super().resize(image, output_size, interpolation=interpolation, antialias=antialias) + + def pad_image( + self, + image: "torch.Tensor", + size_divisor: int = 1, + ) -> "torch.Tensor": + r""" + Center pad a batch of images to be a multiple of `size_divisor`. + + Args: + image (`torch.Tensor`): + Image to pad. Can be a batch of images of dimensions (N, C, H, W) or a single image of dimensions (C, H, W). + size_divisor (`int`): + The width and height of the image will be padded to a multiple of this number. + """ + height, width = image.shape[-2:] + + def _get_pad(size, size_divisor): + new_size = math.ceil(size / size_divisor) * size_divisor + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + pad_top, pad_bottom = _get_pad(height, size_divisor) + pad_left, pad_right = _get_pad(width, size_divisor) + padding = (pad_left, pad_top, pad_right, pad_bottom) + return F.pad(image, padding) + + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None, + ) -> list[dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = torch.nn.functional.interpolate( + depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False + ).squeeze() + + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["DPTImageProcessorFast"] diff --git a/src/transformers/models/dpt/modular_dpt.py b/src/transformers/models/dpt/modular_dpt.py new file mode 100644 index 00000000000..43aeffb2608 --- /dev/null +++ b/src/transformers/models/dpt/modular_dpt.py @@ -0,0 +1,313 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Iterable +from typing import TYPE_CHECKING, Optional, Union + +from transformers.image_processing_base import BatchFeature +from transformers.image_transforms import group_images_by_shape, reorder_images +from transformers.models.beit.image_processing_beit_fast import BeitImageProcessorFast + +from ...image_processing_utils_fast import ( + DefaultFastImageProcessorKwargs, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + PILImageResampling, + SizeDict, +) +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + requires_backends, +) + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput + +if is_torch_available(): + import torch + +if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F +elif is_torchvision_available(): + from torchvision.transforms import functional as F + + +def get_resize_output_image_size( + input_image: "torch.Tensor", + output_size: Union[int, Iterable[int]], + keep_aspect_ratio: bool, + multiple: int, +) -> SizeDict: + def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): + x = round(val / multiple) * multiple + + if max_val is not None and x > max_val: + x = math.floor(val / multiple) * multiple + + if x < min_val: + x = math.ceil(val / multiple) * multiple + + return x + + input_height, input_width = input_image.shape[-2:] + output_height, output_width = output_size + + # determine new height and width + scale_height = output_height / input_height + scale_width = output_width / input_width + + if keep_aspect_ratio: + # scale as little as possible + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + + new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple) + new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple) + + return SizeDict(height=new_height, width=new_width) + + +class DPTFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + ensure_multiple_of (`int`, *optional*, defaults to 1): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden + by `ensure_multiple_of` in `preprocess`. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in + combination with DPT. + size_divisor (`int`, *optional*): + If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the + DINOv2 paper, which uses the model in combination with DPT. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can + be overidden by `keep_aspect_ratio` in `preprocess`. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. + """ + + ensure_multiple_of: Optional[int] + size_divisor: Optional[int] + do_pad: Optional[bool] + keep_aspect_ratio: Optional[bool] + do_reduce_labels: Optional[bool] + + +@auto_docstring +class DPTImageProcessorFast(BeitImageProcessorFast): + resample = PILImageResampling.BICUBIC + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 384, "width": 384} + do_resize = True + do_rescale = True + do_normalize = True + do_pad = False + rescale_factor = 1 / 255 + ensure_multiple_of = 1 + keep_aspect_ratio = False + do_reduce_labels = False + crop_size = None + do_center_crop = None + do_reduce_labels = None + + valid_kwargs = DPTFastImageProcessorKwargs + + def from_dict(): + raise NotImplementedError("No need to override this method") + + def resize( + self, + image: "torch.Tensor", + size: SizeDict, + interpolation: "F.InterpolationMode" = None, + antialias: bool = True, + ensure_multiple_of: Optional[int] = 1, + keep_aspect_ratio: bool = False, + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. + antialias (`bool`, *optional*, defaults to `True`): + Whether to use antialiasing when resizing the image + ensure_multiple_of (`int`, *optional*): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, and `do_resize` is `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. + + Returns: + `torch.Tensor`: The resized image. + """ + if not size.height or not size.width: + raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}") + + output_size = get_resize_output_image_size( + image, + output_size=(size.height, size.width), + keep_aspect_ratio=keep_aspect_ratio, + multiple=ensure_multiple_of, + ) + return BeitImageProcessorFast().resize(image, output_size, interpolation=interpolation, antialias=antialias) + + def pad_image( + self, + image: "torch.Tensor", + size_divisor: int = 1, + ) -> "torch.Tensor": + r""" + Center pad a batch of images to be a multiple of `size_divisor`. + + Args: + image (`torch.Tensor`): + Image to pad. Can be a batch of images of dimensions (N, C, H, W) or a single image of dimensions (C, H, W). + size_divisor (`int`): + The width and height of the image will be padded to a multiple of this number. + """ + height, width = image.shape[-2:] + + def _get_pad(size, size_divisor): + new_size = math.ceil(size / size_divisor) * size_divisor + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + pad_top, pad_bottom = _get_pad(height, size_divisor) + pad_left, pad_right = _get_pad(width, size_divisor) + padding = (pad_left, pad_top, pad_right, pad_bottom) + return F.pad(image, padding) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_reduce_labels: bool, + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + keep_aspect_ratio: bool, + ensure_multiple_of: Optional[int], + do_pad: bool, + size_divisor: Optional[int], + **kwargs, + ) -> BatchFeature: + if do_reduce_labels: + images = self.reduce_label(images) + + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize( + image=stacked_images, + size=size, + interpolation=interpolation, + ensure_multiple_of=ensure_multiple_of, + keep_aspect_ratio=keep_aspect_ratio, + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + if do_pad: + stacked_images = self.pad_image(stacked_images, size_divisor) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return processed_images + + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, list[tuple[int, int]], None]] = None, + ) -> list[dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = torch.nn.functional.interpolate( + depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False + ).squeeze() + + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["DPTImageProcessorFast"] diff --git a/tests/models/dpt/test_image_processing_dpt.py b/tests/models/dpt/test_image_processing_dpt.py index 8d5e8ea75ef..f0a80a6e14b 100644 --- a/tests/models/dpt/test_image_processing_dpt.py +++ b/tests/models/dpt/test_image_processing_dpt.py @@ -20,6 +20,7 @@ from datasets import load_dataset from transformers.file_utils import is_torch_available, is_vision_available from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torchvision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -32,6 +33,9 @@ if is_vision_available(): from transformers import DPTImageProcessor + if is_torchvision_available(): + from transformers import DPTImageProcessorFast + class DPTImageProcessingTester: def __init__( @@ -114,6 +118,7 @@ def prepare_semantic_batch_inputs(): @require_vision class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = DPTImageProcessor if is_vision_available() else None + fast_image_processing_class = DPTImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -124,170 +129,236 @@ class DPTImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "image_mean")) - self.assertTrue(hasattr(image_processing, "image_std")) - self.assertTrue(hasattr(image_processing, "do_normalize")) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_rescale")) - self.assertTrue(hasattr(image_processing, "rescale_factor")) - self.assertTrue(hasattr(image_processing, "do_pad")) - self.assertTrue(hasattr(image_processing, "size_divisor")) - self.assertTrue(hasattr(image_processing, "do_reduce_labels")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_pad")) + self.assertTrue(hasattr(image_processing, "size_divisor")) + self.assertTrue(hasattr(image_processing, "do_reduce_labels")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processing_class = image_processing_class(**self.image_processor_dict) + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) def test_padding(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - image = np.random.randn(3, 249, 491) - - # test individual method - image = image_processing.pad_image(image, size_divisor=4) - self.assertTrue(image.shape[1] % 4 == 0) - self.assertTrue(image.shape[2] % 4 == 0) - - # test by calling - pixel_values = image_processing.preprocess( - image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt" - ).pixel_values - self.assertTrue(pixel_values.shape[2] % 4 == 0) - self.assertTrue(pixel_values.shape[3] % 4 == 0) + for image_processing_class in self.image_processor_list: + if image_processing_class == DPTImageProcessorFast: + image = torch.arange(0, 366777, 1, dtype=torch.uint8).reshape(3, 249, 491) + image_processor = image_processing_class(**self.image_processor_dict) + padded_image = image_processor.pad_image(image, size_divisor=4) + self.assertTrue(padded_image.shape[1] % 4 == 0) + self.assertTrue(padded_image.shape[2] % 4 == 0) + pixel_values = image_processor.preprocess( + image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt" + ).pixel_values + self.assertTrue(pixel_values.shape[2] % 4 == 0) + self.assertTrue(pixel_values.shape[3] % 4 == 0) + else: + image_processor = image_processing_class(**self.image_processor_dict) + image = np.random.randn(3, 249, 491) + image = image_processor.pad_image(image, size_divisor=4) + self.assertTrue(image.shape[1] % 4 == 0) + self.assertTrue(image.shape[2] % 4 == 0) + pixel_values = image_processor.preprocess( + image, do_rescale=False, do_resize=False, do_pad=True, size_divisor=4, return_tensors="pt" + ).pixel_values + self.assertTrue(pixel_values.shape[2] % 4 == 0) + self.assertTrue(pixel_values.shape[3] % 4 == 0) def test_keep_aspect_ratio(self): size = {"height": 512, "width": 512} - image_processor = DPTImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(size=size, keep_aspect_ratio=True, ensure_multiple_of=32) - image = np.zeros((489, 640, 3)) + image = np.zeros((489, 640, 3)) - pixel_values = image_processor(image, return_tensors="pt").pixel_values + pixel_values = image_processor(image, return_tensors="pt").pixel_values - self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672]) + self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672]) # Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_call_segmentation_maps def test_call_segmentation_maps(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) - maps = [] - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) - maps.append(torch.zeros(image.shape[-2:]).long()) + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + maps = [] + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + maps.append(torch.zeros(image.shape[-2:]).long()) - # Test not batched input - encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 1, - self.image_processor_tester.num_channels, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 1, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + # Test not batched input + encoding = image_processor(image_inputs[0], maps[0], return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test batched - encoding = image_processor(image_inputs, maps, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.num_channels, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - self.image_processor_tester.batch_size, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + # Test batched + encoding = image_processor(image_inputs, maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.num_channels, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + self.image_processor_tester.batch_size, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test not batched input (PIL images) - image, segmentation_map = prepare_semantic_single_inputs() + # Test not batched input (PIL images) + image, segmentation_map = prepare_semantic_single_inputs() - encoding = image_processor(image, segmentation_map, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 1, - self.image_processor_tester.num_channels, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 1, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + encoding = image_processor(image, segmentation_map, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 1, + self.image_processor_tester.num_channels, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 1, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Test batched input (PIL images) - images, segmentation_maps = prepare_semantic_batch_inputs() + # Test batched input (PIL images) + images, segmentation_maps = prepare_semantic_batch_inputs() - encoding = image_processor(images, segmentation_maps, return_tensors="pt") - self.assertEqual( - encoding["pixel_values"].shape, - ( - 2, - self.image_processor_tester.num_channels, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual( - encoding["labels"].shape, - ( - 2, - self.image_processor_tester.size["height"], - self.image_processor_tester.size["width"], - ), - ) - self.assertEqual(encoding["labels"].dtype, torch.long) - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + encoding = image_processor(images, segmentation_maps, return_tensors="pt") + self.assertEqual( + encoding["pixel_values"].shape, + ( + 2, + self.image_processor_tester.num_channels, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual( + encoding["labels"].shape, + ( + 2, + self.image_processor_tester.size["height"], + self.image_processor_tester.size["width"], + ), + ) + self.assertEqual(encoding["labels"].dtype, torch.long) + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) - # Copied from transformers.tests.models.beit.test_image_processing_beit.BeitImageProcessingTest.test_reduce_labels def test_reduce_labels(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) - # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 - image, map = prepare_semantic_single_inputs() - encoding = image_processor(image, map, return_tensors="pt") - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 150) + # ADE20k has 150 classes, and the background is included, so labels should be between 0 and 150 + image, map = prepare_semantic_single_inputs() + encoding = image_processor(image, map, return_tensors="pt") + labels_no_reduce = encoding["labels"].clone() + self.assertTrue(labels_no_reduce.min().item() >= 0) + self.assertTrue(labels_no_reduce.max().item() <= 150) + # Get the first non-zero label coords and value, for comparison when do_reduce_labels is True + non_zero_positions = (labels_no_reduce > 0).nonzero() + first_non_zero_coords = tuple(non_zero_positions[0].tolist()) + first_non_zero_value = labels_no_reduce[first_non_zero_coords].item() - image_processor.do_reduce_labels = True - encoding = image_processor(image, map, return_tensors="pt") - self.assertTrue(encoding["labels"].min().item() >= 0) - self.assertTrue(encoding["labels"].max().item() <= 255) + image_processor.do_reduce_labels = True + encoding = image_processor(image, map, return_tensors="pt") + self.assertTrue(encoding["labels"].min().item() >= 0) + self.assertTrue(encoding["labels"].max().item() <= 255) + # Compare with non-reduced label to see if it's reduced by 1 + self.assertEqual(encoding["labels"][first_non_zero_coords].item(), first_non_zero_value - 1) + + def test_slow_fast_equivalence(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_image, dummy_map = prepare_semantic_single_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + image_encoding_slow = image_processor_slow(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + image_encoding_fast = image_processor_fast(dummy_image, segmentation_maps=dummy_map, return_tensors="pt") + + self.assertTrue(torch.allclose(image_encoding_slow.pixel_values, image_encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(image_encoding_slow.pixel_values - image_encoding_fast.pixel_values)).item(), 1e-3 + ) + self.assertTrue(torch.allclose(image_encoding_slow.labels, image_encoding_fast.labels, atol=1e-1)) + + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop: + self.skipTest( + reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors" + ) + + dummy_images, dummy_maps = prepare_semantic_batch_inputs() + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, segmentation_maps=dummy_maps, return_tensors="pt") + + self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1)) + self.assertLessEqual( + torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3 + )