From 1fed6166c00b800330fcda8494f78cbcad8e4e3b Mon Sep 17 00:00:00 2001 From: Henrik Matthiesen <126027334+henrikm11@users.noreply.github.com> Date: Thu, 5 Jun 2025 00:59:17 +0200 Subject: [PATCH] added fast image processor for ZoeDepth and expanded tests accordingly (#38515) * added fast image processor for ZoeDepth and expanded tests accordingly * added fast image processor for ZoeDepth and expanded tests accordingly, hopefully fixed repo consistency issue too now * final edits for zoedept fast image processor * final minor edit for zoedepth fast imate procesor --- docs/source/en/model_doc/zoedepth.md | 5 + .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/zoedepth/__init__.py | 1 + .../image_processing_zoedepth_fast.py | 328 ++++++++++++++++++ .../test_image_processing_zoedepth.py | 135 +++++-- 5 files changed, 435 insertions(+), 36 deletions(-) create mode 100644 src/transformers/models/zoedepth/image_processing_zoedepth_fast.py diff --git a/docs/source/en/model_doc/zoedepth.md b/docs/source/en/model_doc/zoedepth.md index 59bc483d8cf..d392b34abba 100644 --- a/docs/source/en/model_doc/zoedepth.md +++ b/docs/source/en/model_doc/zoedepth.md @@ -119,6 +119,11 @@ Image.fromarray(depth.astype("uint8")) [[autodoc]] ZoeDepthImageProcessor - preprocess +## ZoeDepthImageProcessorFast + +[[autodoc]] ZoeDepthImageProcessorFast + - preprocess + ## ZoeDepthForDepthEstimation [[autodoc]] ZoeDepthForDepthEstimation diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 059c4c40e68..f992369f27c 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -170,7 +170,7 @@ else: ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")), ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")), - ("zoedepth", ("ZoeDepthImageProcessor",)), + ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")), ] ) diff --git a/src/transformers/models/zoedepth/__init__.py b/src/transformers/models/zoedepth/__init__.py index 99879e0f85c..abc436fa801 100644 --- a/src/transformers/models/zoedepth/__init__.py +++ b/src/transformers/models/zoedepth/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_zoedepth import * from .image_processing_zoedepth import * + from .image_processing_zoedepth_fast import * from .modeling_zoedepth import * else: import sys diff --git a/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py b/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py new file mode 100644 index 00000000000..abc72cd8cd1 --- /dev/null +++ b/src/transformers/models/zoedepth/image_processing_zoedepth_fast.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for ZoeDepth.""" + +from typing import ( + Dict, + List, + Optional, + Tuple, + Union, +) + +import numpy as np + +from ...image_processing_utils import ( + BatchFeature, +) +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + get_image_size, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, + requires_backends, +) +from .image_processing_zoedepth import get_resize_output_image_size +from .modeling_zoedepth import ZoeDepthDepthEstimatorOutput + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + from torchvision.transforms import InterpolationMode + + +logger = logging.get_logger(__name__) + + +class ZoeDepthFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + """ + do_pad (`bool`, *optional*, defaults to `True`): + Whether to apply pad the input. + keep_aspect_ratio (`bool`, *optional*, defaults to `True`): + If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it + for both dimensions. This ensures that the image is scaled down as little as possible while still fitting + within the desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a + size that is a multiple of this value by flooring the height and width to the nearest multiple of this value. + Can be overridden by `keep_aspect_ratio` in `preprocess`. + ensure_multiple_of (`int`, *optional*, defaults to 32): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring + the height and width to the nearest multiple of this value. + Works both with and without `keep_aspect_ratio` being set to `True`. + Can be overridden by `ensure_multiple_of` in `preprocess`. + """ + + do_pad: Optional[bool] + keep_aspect_ratio: Optional[bool] + ensure_multiple_of: Optional[int] + + +@auto_docstring +class ZoeDepthImageProcessorFast(BaseImageProcessorFast): + do_pad = True + do_rescale = True + do_normalize = True + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + do_resize = True + size = {"height": 384, "width": 512} + resample = PILImageResampling.BILINEAR + keep_aspect_ratio = True + ensure_multiple_of = 1 / 32 + valid_kwargs = ZoeDepthFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs]) -> None: + super().__init__(**kwargs) + + @auto_docstring + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def resize( + self, + images: "torch.Tensor", + size: SizeDict, + keep_aspect_ratio: bool = False, + ensure_multiple_of: int = 1, + interpolation: Optional["F.InterpolationMode"] = None, + ) -> "torch.Tensor": + """ + Resize an image or batchd images to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image + is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is + set, the image is resized to a size that is a multiple of this value. + + Args: + images (`torch.Tensor`): + Images to resize. + size (`Dict[str, int]`): + Target size of the output image. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. + ensure_multiple_of (`int`, *optional*, defaults to 1): + The image is resized to a size that is a multiple of this value. + interpolation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size + specified in `size`. + """ + if not size.height or not size.width: + raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size}") + output_size = get_resize_output_image_size( + images, + output_size=(size.height, size.width), + keep_aspect_ratio=keep_aspect_ratio, + multiple=ensure_multiple_of, + input_data_format=ChannelDimension.FIRST, + ) + height, width = output_size + + resized_images = torch.nn.functional.interpolate( + images, (int(height), int(width)), mode=interpolation.value, align_corners=True + ) + + return resized_images + + def _pad_images( + self, + images: "torch.Tensor", + ): + """ + Args: + image (`torch.Tensor`): + Image to pad. + """ + height, width = get_image_size(images, channel_dim=ChannelDimension.FIRST) + + pad_height = int(np.sqrt(height / 2) * 3) + pad_width = int(np.sqrt(width / 2) * 3) + + return F.pad(images, padding=(pad_width, pad_height), padding_mode="reflect") + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + keep_aspect_ratio: Optional[bool], + ensure_multiple_of: Optional[int], + interpolation: Optional["F.InterpolationMode"], + do_pad: bool, + do_rescale: bool, + rescale_factor: Optional[float], + do_normalize: bool, + image_mean: Optional[Union[float, List[float]]], + image_std: Optional[Union[float, List[float]]], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_rescale: + stacked_images = self.rescale(stacked_images, rescale_factor) + if do_pad: + stacked_images = self._pad_images(images=stacked_images) + if do_resize: + stacked_images = self.resize( + stacked_images, size, keep_aspect_ratio, ensure_multiple_of, interpolation + ) + if do_normalize: + stacked_images = self.normalize(stacked_images, image_mean, image_std) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + processed_images = torch.stack(resized_images, dim=0) if return_tensors else resized_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + def post_process_depth_estimation( + self, + outputs: "ZoeDepthDepthEstimatorOutput", + source_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None, + target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None, + outputs_flipped: Optional[Union["ZoeDepthDepthEstimatorOutput", None]] = None, + do_remove_padding: Optional[Union[bool, None]] = None, + ) -> List[Dict[str, TensorType]]: + """ + Converts the raw output of [`ZoeDepthDepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`ZoeDepthDepthEstimatorOutput`]): + Raw outputs of the model. + source_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the source size + (height, width) of each image in the batch before preprocessing. This argument should be dealt as + "required" unless the user passes `do_remove_padding=False` as input to this function. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + outputs_flipped ([`ZoeDepthDepthEstimatorOutput`], *optional*): + Raw outputs of the model from flipped input (averaged out in the end). + do_remove_padding (`bool`, *optional*): + By default ZoeDepth adds padding equal to `int(√(height / 2) * 3)` (and similarly for width) to fix the + boundary artifacts in the output depth map, so we need remove this padding during post_processing. The + parameter exists here in case the user changed the image preprocessing to not include padding. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (outputs_flipped is not None) and (predicted_depth.shape != outputs_flipped.predicted_depth.shape): + raise ValueError("Make sure that `outputs` and `outputs_flipped` have the same shape") + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + if do_remove_padding is None: + do_remove_padding = self.do_pad + + if source_sizes is None and do_remove_padding: + raise ValueError( + "Either `source_sizes` should be passed in, or `do_remove_padding` should be set to False" + ) + + if (source_sizes is not None) and (len(predicted_depth) != len(source_sizes)): + raise ValueError( + "Make sure that you pass in as many source image sizes as the batch dimension of the logits" + ) + + if outputs_flipped is not None: + predicted_depth = (predicted_depth + torch.flip(outputs_flipped.predicted_depth, dims=[-1])) / 2 + + predicted_depth = predicted_depth.unsqueeze(1) + + # Zoe Depth model adds padding around the images to fix the boundary artifacts in the output depth map + # The padding length is `int(np.sqrt(img_h/2) * fh)` for the height and similar for the width + # fh (and fw respectively) are equal to '3' by default + # Check [here](https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L57) + # for the original implementation. + # In this section, we remove this padding to get the final depth image and depth prediction + padding_factor_h = padding_factor_w = 3 + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + source_sizes = [None] * len(predicted_depth) if source_sizes is None else source_sizes + for depth, target_size, source_size in zip(predicted_depth, target_sizes, source_sizes): + # depth.shape = [1, H, W] + if source_size is not None: + pad_h = pad_w = 0 + + if do_remove_padding: + pad_h = int(np.sqrt(source_size[0] / 2) * padding_factor_h) + pad_w = int(np.sqrt(source_size[1] / 2) * padding_factor_w) + + depth = F.resize( + depth, + size=[source_size[0] + 2 * pad_h, source_size[1] + 2 * pad_w], + interpolation=InterpolationMode.BICUBIC, + antialias=False, + ) + + if pad_h > 0: + depth = depth[:, pad_h:-pad_h, :] + if pad_w > 0: + depth = depth[:, :, pad_w:-pad_w] + + if target_size is not None: + target_size = [target_size[0], target_size[1]] + depth = F.resize( + depth, + size=target_size, + interpolation=InterpolationMode.BICUBIC, + antialias=False, + ) + depth = depth.squeeze(0) + # depth.shape = [H, W] + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["ZoeDepthImageProcessorFast"] diff --git a/tests/models/zoedepth/test_image_processing_zoedepth.py b/tests/models/zoedepth/test_image_processing_zoedepth.py index 69558ad3c47..5e5ec4c0425 100644 --- a/tests/models/zoedepth/test_image_processing_zoedepth.py +++ b/tests/models/zoedepth/test_image_processing_zoedepth.py @@ -14,18 +14,30 @@ import unittest +from dataclasses import dataclass import numpy as np -from transformers.file_utils import is_vision_available from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs +if is_torch_available(): + import torch + if is_vision_available(): from transformers import ZoeDepthImageProcessor + if is_torchvision_available(): + from transformers import ZoeDepthImageProcessorFast + + +@dataclass +class ZoeDepthDepthOutputProxy: + predicted_depth: torch.FloatTensor = None + class ZoeDepthImageProcessingTester: def __init__( @@ -43,7 +55,7 @@ class ZoeDepthImageProcessingTester: do_normalize=True, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], - do_pad=False, + do_pad=True, ): size = size if size is not None else {"height": 18, "width": 18} self.parent = parent @@ -87,11 +99,25 @@ class ZoeDepthImageProcessingTester: torchify=torchify, ) + def prepare_depth_outputs(self): + depth_tensors = prepare_image_inputs( + batch_size=self.batch_size, + num_channels=1, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=True, + torchify=True, + ) + depth_tensors = [depth_tensor.squeeze(0) for depth_tensor in depth_tensors] + stacked_depth_tensors = torch.stack(depth_tensors, dim=0) + return ZoeDepthDepthOutputProxy(predicted_depth=stacked_depth_tensors) + @require_torch @require_vision class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = ZoeDepthImageProcessor if is_vision_available() else None + fast_image_processing_class = ZoeDepthImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -115,11 +141,15 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): self.assertTrue(hasattr(image_processing, "do_pad")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 18, "width": 18}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 18, "width": 18}) - image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + for image_processing_class in self.image_processor_list: + modified_dict = self.image_processor_dict + modified_dict["size"] = 42 + image_processor = image_processing_class(**modified_dict) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) def test_ensure_multiple_of(self): # Test variable by turning off all other variables which affect the size, size which is not multiple of 32 @@ -127,14 +157,15 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): size = {"height": 380, "width": 513} multiple = 32 - image_processor = ZoeDepthImageProcessor( - do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False - ) - pixel_values = image_processor(image, return_tensors="pt").pixel_values + for image_processor_class in self.image_processor_list: + image_processor = image_processor_class( + do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False + ) + pixel_values = image_processor(image, return_tensors="pt").pixel_values - self.assertEqual(list(pixel_values.shape), [1, 3, 384, 512]) - self.assertTrue(pixel_values.shape[2] % multiple == 0) - self.assertTrue(pixel_values.shape[3] % multiple == 0) + self.assertEqual(list(pixel_values.shape), [1, 3, 384, 512]) + self.assertTrue(pixel_values.shape[2] % multiple == 0) + self.assertTrue(pixel_values.shape[3] % multiple == 0) # Test variable by turning off all other variables which affect the size, size which is already multiple of 32 image = np.zeros((511, 511, 3)) @@ -142,14 +173,15 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): height, width = 512, 512 size = {"height": height, "width": width} multiple = 32 - image_processor = ZoeDepthImageProcessor( - do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False - ) - pixel_values = image_processor(image, return_tensors="pt").pixel_values + for image_processor_class in self.image_processor_list: + image_processor = image_processor_class( + do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False + ) + pixel_values = image_processor(image, return_tensors="pt").pixel_values - self.assertEqual(list(pixel_values.shape), [1, 3, height, width]) - self.assertTrue(pixel_values.shape[2] % multiple == 0) - self.assertTrue(pixel_values.shape[3] % multiple == 0) + self.assertEqual(list(pixel_values.shape), [1, 3, height, width]) + self.assertTrue(pixel_values.shape[2] % multiple == 0) + self.assertTrue(pixel_values.shape[3] % multiple == 0) def test_keep_aspect_ratio(self): # Test `keep_aspect_ratio=True` by turning off all other variables which affect the size @@ -157,30 +189,63 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image = np.zeros((height, width, 3)) size = {"height": 512, "width": 512} - image_processor = ZoeDepthImageProcessor(do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1) - pixel_values = image_processor(image, return_tensors="pt").pixel_values + for image_processor_class in self.image_processor_list: + image_processor = image_processor_class( + do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1 + ) + pixel_values = image_processor(image, return_tensors="pt").pixel_values - # As can be seen, the image is resized to the maximum size that fits in the specified size - self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670]) + # As can be seen, the image is resized to the maximum size that fits in the specified size + self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670]) # Test `keep_aspect_ratio=False` by turning off all other variables which affect the size - image_processor = ZoeDepthImageProcessor( - do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1 - ) - pixel_values = image_processor(image, return_tensors="pt").pixel_values + for image_processor_class in self.image_processor_list: + image_processor = image_processor_class( + do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1 + ) + pixel_values = image_processor(image, return_tensors="pt").pixel_values - # As can be seen, the size is respected - self.assertEqual(list(pixel_values.shape), [1, 3, size["height"], size["width"]]) + # As can be seen, the size is respected + self.assertEqual(list(pixel_values.shape), [1, 3, size["height"], size["width"]]) # Test `keep_aspect_ratio=True` with `ensure_multiple_of` set image = np.zeros((489, 640, 3)) size = {"height": 511, "width": 511} multiple = 32 - image_processor = ZoeDepthImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple) + for image_processor_class in self.image_processor_list: + image_processor = image_processor_class(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple) - pixel_values = image_processor(image, return_tensors="pt").pixel_values + pixel_values = image_processor(image, return_tensors="pt").pixel_values - self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672]) - self.assertTrue(pixel_values.shape[2] % multiple == 0) - self.assertTrue(pixel_values.shape[3] % multiple == 0) + self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672]) + self.assertTrue(pixel_values.shape[2] % multiple == 0) + self.assertTrue(pixel_values.shape[3] % multiple == 0) + + # extend this test to check if removal of padding works fine! + def test_post_processing_equivalence(self): + outputs = self.image_processor_tester.prepare_depth_outputs() + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + + source_sizes = [outputs.predicted_depth.shape[1:]] * self.image_processor_tester.batch_size + target_sizes = [ + torch.Size([outputs.predicted_depth.shape[1] // 2, *(outputs.predicted_depth.shape[2:])]) + ] * self.image_processor_tester.batch_size + + processed_fast = image_processor_fast.post_process_depth_estimation( + outputs, + source_sizes=source_sizes, + target_sizes=target_sizes, + ) + processed_slow = image_processor_slow.post_process_depth_estimation( + outputs, + source_sizes=source_sizes, + target_sizes=target_sizes, + ) + for pred_fast, pred_slow in zip(processed_fast, processed_slow): + depth_fast = pred_fast["predicted_depth"] + depth_slow = pred_slow["predicted_depth"] + + torch.testing.assert_close(depth_fast, depth_slow, atol=1e-1, rtol=1e-3) + self.assertLessEqual(torch.mean(torch.abs(depth_fast.float() - depth_slow.float())).item(), 5e-3)