added fast image processor for ZoeDepth and expanded tests accordingly (#38515)

* added fast image processor for ZoeDepth and expanded tests accordingly

* added fast image processor for ZoeDepth and expanded tests accordingly; hopefully this also fixes the repo consistency issue

* final edits for ZoeDepth fast image processor

* final minor edit for ZoeDepth fast image processor
This commit is contained in:
Henrik Matthiesen 2025-06-05 00:59:17 +02:00 committed by GitHub
parent a510be20f3
commit 1fed6166c0
5 changed files with 435 additions and 36 deletions


@@ -119,6 +119,11 @@ Image.fromarray(depth.astype("uint8"))
[[autodoc]] ZoeDepthImageProcessor
- preprocess
## ZoeDepthImageProcessorFast
[[autodoc]] ZoeDepthImageProcessorFast
- preprocess
## ZoeDepthForDepthEstimation
[[autodoc]] ZoeDepthForDepthEstimation
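The new fast processor mirrors the slow one documented above; a minimal usage sketch (not part of the diff), assuming the `Intel/zoedepth-nyu-kitti` checkpoint used elsewhere in the ZoeDepth docs and a PIL `image` variable:

from transformers import ZoeDepthImageProcessorFast

# Illustrative sketch only; checkpoint name and `image` are assumptions.
image_processor = ZoeDepthImageProcessorFast.from_pretrained("Intel/zoedepth-nyu-kitti")
inputs = image_processor(images=image, return_tensors="pt")
print(inputs.pixel_values.shape)  # e.g. torch.Size([1, 3, H, W]) after padding and resizing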


@@ -170,7 +170,7 @@ else:
("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
("zoedepth", ("ZoeDepthImageProcessor",)),
("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
]
)
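With `ZoeDepthImageProcessorFast` registered in the mapping above, `AutoImageProcessor` can now return the fast variant; a hedged sketch (checkpoint name is an assumption):

from transformers import AutoImageProcessor

# Illustrative sketch only; `use_fast=True` should resolve to the fast class via the updated mapping.
processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", use_fast=True)
print(type(processor).__name__)  # expected: ZoeDepthImageProcessorFast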


@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_zoedepth import *
from .image_processing_zoedepth import *
from .image_processing_zoedepth_fast import *
from .modeling_zoedepth import *
else:
import sys


@@ -0,0 +1,328 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ZoeDepth."""
from typing import (
Dict,
List,
Optional,
Tuple,
Union,
)
import numpy as np
from ...image_processing_utils import (
BatchFeature,
)
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
auto_docstring,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
logging,
requires_backends,
)
from .image_processing_zoedepth import get_resize_output_image_size
from .modeling_zoedepth import ZoeDepthDepthEstimatorOutput
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
logger = logging.get_logger(__name__)
class ZoeDepthFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
"""
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the input.
keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it
for both dimensions. This ensures that the image is scaled down as little as possible while still fitting
within the desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a
size that is a multiple of this value by flooring the height and width to the nearest multiple of this value.
Can be overridden by `keep_aspect_ratio` in `preprocess`.
ensure_multiple_of (`int`, *optional*, defaults to 32):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
the height and width to the nearest multiple of this value.
Works both with and without `keep_aspect_ratio` being set to `True`.
Can be overridden by `ensure_multiple_of` in `preprocess`.
"""
do_pad: Optional[bool]
keep_aspect_ratio: Optional[bool]
ensure_multiple_of: Optional[int]
@auto_docstring
class ZoeDepthImageProcessorFast(BaseImageProcessorFast):
do_pad = True
do_rescale = True
do_normalize = True
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
size = {"height": 384, "width": 512}
resample = PILImageResampling.BILINEAR
keep_aspect_ratio = True
ensure_multiple_of = 32
valid_kwargs = ZoeDepthFastImageProcessorKwargs
def __init__(self, **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs]) -> None:
super().__init__(**kwargs)
@auto_docstring
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[ZoeDepthFastImageProcessorKwargs],
) -> BatchFeature:
return super().preprocess(images, **kwargs)
def resize(
self,
images: "torch.Tensor",
size: SizeDict,
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
interpolation: Optional["F.InterpolationMode"] = None,
) -> "torch.Tensor":
"""
Resize an image or a batch of images to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
set, the image is resized to a size that is a multiple of this value.
Args:
images (`torch.Tensor`):
Images to resize.
size (`Dict[str, int]`):
Target size of the output image.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
ensure_multiple_of (`int`, *optional*, defaults to 1):
The image is resized to a size that is a multiple of this value.
interpolation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
specified in `size`.
"""
if not size.height or not size.width:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size}")
output_size = get_resize_output_image_size(
images,
output_size=(size.height, size.width),
keep_aspect_ratio=keep_aspect_ratio,
multiple=ensure_multiple_of,
input_data_format=ChannelDimension.FIRST,
)
height, width = output_size
resized_images = torch.nn.functional.interpolate(
images, (int(height), int(width)), mode=interpolation.value, align_corners=True
)
return resized_images
def _pad_images(
self,
images: "torch.Tensor",
):
"""
Args:
images (`torch.Tensor`):
Images to pad, using reflection padding on each spatial side.
"""
height, width = get_image_size(images, channel_dim=ChannelDimension.FIRST)
pad_height = int(np.sqrt(height / 2) * 3)
pad_width = int(np.sqrt(width / 2) * 3)
return F.pad(images, padding=(pad_width, pad_height), padding_mode="reflect")
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
keep_aspect_ratio: Optional[bool],
ensure_multiple_of: Optional[int],
interpolation: Optional["F.InterpolationMode"],
do_pad: bool,
do_rescale: bool,
rescale_factor: Optional[float],
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_rescale:
stacked_images = self.rescale(stacked_images, rescale_factor)
if do_pad:
stacked_images = self._pad_images(images=stacked_images)
if do_resize:
stacked_images = self.resize(
stacked_images, size, keep_aspect_ratio, ensure_multiple_of, interpolation
)
if do_normalize:
stacked_images = self.normalize(stacked_images, image_mean, image_std)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
processed_images = torch.stack(resized_images, dim=0) if return_tensors else resized_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
def post_process_depth_estimation(
self,
outputs: "ZoeDepthDepthEstimatorOutput",
source_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
outputs_flipped: Optional[Union["ZoeDepthDepthEstimatorOutput", None]] = None,
do_remove_padding: Optional[Union[bool, None]] = None,
) -> List[Dict[str, TensorType]]:
"""
Converts the raw output of [`ZoeDepthDepthEstimatorOutput`] into final depth predictions and depth PIL images.
Only supports PyTorch.
Args:
outputs ([`ZoeDepthDepthEstimatorOutput`]):
Raw outputs of the model.
source_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the source size
(height, width) of each image in the batch before preprocessing. This argument should be dealt as
"required" unless the user passes `do_remove_padding=False` as input to this function.
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
outputs_flipped ([`ZoeDepthDepthEstimatorOutput`], *optional*):
Raw outputs of the model from flipped input (averaged out in the end).
do_remove_padding (`bool`, *optional*):
By default ZoeDepth adds padding equal to `int(np.sqrt(height / 2) * 3)` (and similarly for width) to fix the
boundary artifacts in the output depth map, so we need to remove this padding during post-processing. The
parameter exists here in case the user changed the image preprocessing to not include padding.
Returns:
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
predictions.
"""
requires_backends(self, "torch")
predicted_depth = outputs.predicted_depth
if (outputs_flipped is not None) and (predicted_depth.shape != outputs_flipped.predicted_depth.shape):
raise ValueError("Make sure that `outputs` and `outputs_flipped` have the same shape")
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
)
if do_remove_padding is None:
do_remove_padding = self.do_pad
if source_sizes is None and do_remove_padding:
raise ValueError(
"Either `source_sizes` should be passed in, or `do_remove_padding` should be set to False"
)
if (source_sizes is not None) and (len(predicted_depth) != len(source_sizes)):
raise ValueError(
"Make sure that you pass in as many source image sizes as the batch dimension of the logits"
)
if outputs_flipped is not None:
predicted_depth = (predicted_depth + torch.flip(outputs_flipped.predicted_depth, dims=[-1])) / 2
predicted_depth = predicted_depth.unsqueeze(1)
# Zoe Depth model adds padding around the images to fix the boundary artifacts in the output depth map
# The padding length is `int(np.sqrt(img_h/2) * fh)` for the height and similar for the width
# fh (and fw respectively) are equal to '3' by default
# Check [here](https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L57)
# for the original implementation.
# In this section, we remove this padding to get the final depth image and depth prediction
padding_factor_h = padding_factor_w = 3
results = []
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
source_sizes = [None] * len(predicted_depth) if source_sizes is None else source_sizes
for depth, target_size, source_size in zip(predicted_depth, target_sizes, source_sizes):
# depth.shape = [1, H, W]
if source_size is not None:
pad_h = pad_w = 0
if do_remove_padding:
pad_h = int(np.sqrt(source_size[0] / 2) * padding_factor_h)
pad_w = int(np.sqrt(source_size[1] / 2) * padding_factor_w)
depth = F.resize(
depth,
size=[source_size[0] + 2 * pad_h, source_size[1] + 2 * pad_w],
interpolation=InterpolationMode.BICUBIC,
antialias=False,
)
if pad_h > 0:
depth = depth[:, pad_h:-pad_h, :]
if pad_w > 0:
depth = depth[:, :, pad_w:-pad_w]
if target_size is not None:
target_size = [target_size[0], target_size[1]]
depth = F.resize(
depth,
size=target_size,
interpolation=InterpolationMode.BICUBIC,
antialias=False,
)
depth = depth.squeeze(0)
# depth.shape = [H, W]
results.append({"predicted_depth": depth})
return results
__all__ = ["ZoeDepthImageProcessorFast"]
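Taken together, `_preprocess` above rescales, reflect-pads (by `int(np.sqrt(dim / 2) * 3)` per side), resizes and normalizes, and `post_process_depth_estimation` undoes the padding and optionally resizes to a target size. A hedged end-to-end sketch, assuming the `Intel/zoedepth-nyu-kitti` checkpoint and a PIL `image`:

import torch
from transformers import ZoeDepthForDepthEstimation, ZoeDepthImageProcessorFast

# Illustrative sketch only; checkpoint name and `image` are assumptions.
processor = ZoeDepthImageProcessorFast.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# `source_sizes` (original image sizes) are needed so the reflect padding can be removed:
# e.g. for a 384x512 input, int(np.sqrt(384 / 2) * 3) = 41 rows and int(np.sqrt(512 / 2) * 3) = 48
# columns were added on each side during preprocessing.
post = processor.post_process_depth_estimation(outputs, source_sizes=[(image.height, image.width)])
predicted_depth = post[0]["predicted_depth"]  # torch.Tensor of shape (height, width)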


@@ -14,18 +14,30 @@
import unittest
from dataclasses import dataclass
import numpy as np
from transformers.file_utils import is_vision_available
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from transformers import ZoeDepthImageProcessor
if is_torchvision_available():
from transformers import ZoeDepthImageProcessorFast
@dataclass
class ZoeDepthDepthOutputProxy:
predicted_depth: torch.FloatTensor = None
class ZoeDepthImageProcessingTester:
def __init__(
@@ -43,7 +55,7 @@ class ZoeDepthImageProcessingTester:
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_pad=False,
do_pad=True,
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
@@ -87,11 +99,25 @@ class ZoeDepthImageProcessingTester:
torchify=torchify,
)
def prepare_depth_outputs(self):
depth_tensors = prepare_image_inputs(
batch_size=self.batch_size,
num_channels=1,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=True,
torchify=True,
)
depth_tensors = [depth_tensor.squeeze(0) for depth_tensor in depth_tensors]
stacked_depth_tensors = torch.stack(depth_tensors, dim=0)
return ZoeDepthDepthOutputProxy(predicted_depth=stacked_depth_tensors)
@require_torch
@require_vision
class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ZoeDepthImageProcessor if is_vision_available() else None
fast_image_processing_class = ZoeDepthImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@@ -115,10 +141,14 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(image_processing, "do_pad"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
for image_processing_class in self.image_processor_list:
modified_dict = self.image_processor_dict
modified_dict["size"] = 42
image_processor = image_processing_class(**modified_dict)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
def test_ensure_multiple_of(self):
@@ -127,7 +157,8 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
size = {"height": 380, "width": 513}
multiple = 32
image_processor = ZoeDepthImageProcessor(
for image_processor_class in self.image_processor_list:
image_processor = image_processor_class(
do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
@@ -142,7 +173,8 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
height, width = 512, 512
size = {"height": height, "width": width}
multiple = 32
image_processor = ZoeDepthImageProcessor(
for image_processor_class in self.image_processor_list:
image_processor = image_processor_class(
do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
@@ -157,14 +189,18 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image = np.zeros((height, width, 3))
size = {"height": 512, "width": 512}
image_processor = ZoeDepthImageProcessor(do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1)
for image_processor_class in self.image_processor_list:
image_processor = image_processor_class(
do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
# As can be seen, the image is resized to the maximum size that fits in the specified size
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670])
# Test `keep_aspect_ratio=False` by turning off all other variables which affect the size
image_processor = ZoeDepthImageProcessor(
for image_processor_class in self.image_processor_list:
image_processor = image_processor_class(
do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
@@ -177,10 +213,39 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
size = {"height": 511, "width": 511}
multiple = 32
image_processor = ZoeDepthImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple)
for image_processor_class in self.image_processor_list:
image_processor = image_processor_class(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
self.assertTrue(pixel_values.shape[2] % multiple == 0)
self.assertTrue(pixel_values.shape[3] % multiple == 0)
# extend this test to check if removal of padding works fine!
def test_post_processing_equivalence(self):
outputs = self.image_processor_tester.prepare_depth_outputs()
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
source_sizes = [outputs.predicted_depth.shape[1:]] * self.image_processor_tester.batch_size
target_sizes = [
torch.Size([outputs.predicted_depth.shape[1] // 2, *(outputs.predicted_depth.shape[2:])])
] * self.image_processor_tester.batch_size
processed_fast = image_processor_fast.post_process_depth_estimation(
outputs,
source_sizes=source_sizes,
target_sizes=target_sizes,
)
processed_slow = image_processor_slow.post_process_depth_estimation(
outputs,
source_sizes=source_sizes,
target_sizes=target_sizes,
)
for pred_fast, pred_slow in zip(processed_fast, processed_slow):
depth_fast = pred_fast["predicted_depth"]
depth_slow = pred_slow["predicted_depth"]
torch.testing.assert_close(depth_fast, depth_slow, atol=1e-1, rtol=1e-3)
self.assertLessEqual(torch.mean(torch.abs(depth_fast.float() - depth_slow.float())).item(), 5e-3)
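The same fast/slow comparison can be reproduced outside the test harness; a minimal sketch, with tolerances borrowed from the equivalence test above and `image` assumed to be any PIL image or NumPy array:

import torch
from transformers import ZoeDepthImageProcessor, ZoeDepthImageProcessorFast

# Illustrative sketch only; both processors are instantiated with their default configuration.
slow_processor = ZoeDepthImageProcessor()
fast_processor = ZoeDepthImageProcessorFast()

pixel_values_slow = slow_processor(image, return_tensors="pt").pixel_values
pixel_values_fast = fast_processor(image, return_tensors="pt").pixel_values

# The fast path resizes with torch.nn.functional.interpolate rather than PIL, so small
# numerical differences versus the slow path are expected and tolerated.
torch.testing.assert_close(pixel_values_fast, pixel_values_slow, atol=1e-1, rtol=1e-3)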