mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-15 18:48:24 +06:00
added fast image processor for ZoeDepth and expanded tests accordingly (#38515)
* added fast image processor for ZoeDepth and expanded tests accordingly
* added fast image processor for ZoeDepth and expanded tests accordingly, hopefully fixed the repo consistency issue too now
* final edits for ZoeDepth fast image processor
* final minor edit for ZoeDepth fast image processor
This commit is contained in:
parent a510be20f3
commit 1fed6166c0
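In brief, this commit adds a torchvision-backed `ZoeDepthImageProcessorFast` alongside the existing slow processor, registers it in the auto image-processor mapping, exports it from the model package, documents it, and parametrizes the test suite over both classes. A minimal loading sketch (illustrative only, not part of the diff; assumes this commit is installed and uses the public `Intel/zoedepth-nyu-kitti` checkpoint):

```python
from transformers import AutoImageProcessor

# With the registration added below, use_fast=True resolves to the fast class.
processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", use_fast=True)
print(type(processor).__name__)  # expected: ZoeDepthImageProcessorFast
```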
docs/source/en/model_doc/zoedepth.md
@@ -119,6 +119,11 @@ Image.fromarray(depth.astype("uint8"))
 [[autodoc]] ZoeDepthImageProcessor
     - preprocess
 
+## ZoeDepthImageProcessorFast
+
+[[autodoc]] ZoeDepthImageProcessorFast
+    - preprocess
+
 ## ZoeDepthForDepthEstimation
 
 [[autodoc]] ZoeDepthForDepthEstimation
src/transformers/models/auto/image_processing_auto.py
@@ -170,7 +170,7 @@ else:
         ("vitmatte", ("VitMatteImageProcessor", "VitMatteImageProcessorFast")),
         ("xclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("yolos", ("YolosImageProcessor", "YolosImageProcessorFast")),
-        ("zoedepth", ("ZoeDepthImageProcessor",)),
+        ("zoedepth", ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")),
     ]
 )
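The mapping value is a `("SlowClass", "FastClass")` tuple. A simplified sketch of the lookup convention (hypothetical illustration, not the actual transformers resolution code):

```python
# Hypothetical, simplified version of how a (slow, fast) registry entry is used.
REGISTRY = {"zoedepth": ("ZoeDepthImageProcessor", "ZoeDepthImageProcessorFast")}

def pick_class_name(model_type: str, use_fast: bool) -> str:
    slow, *fast = REGISTRY[model_type]
    return fast[0] if use_fast and fast else slow

print(pick_class_name("zoedepth", use_fast=True))   # ZoeDepthImageProcessorFast
print(pick_class_name("zoedepth", use_fast=False))  # ZoeDepthImageProcessor
```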
src/transformers/models/zoedepth/__init__.py
@@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure
 if TYPE_CHECKING:
     from .configuration_zoedepth import *
     from .image_processing_zoedepth import *
+    from .image_processing_zoedepth_fast import *
     from .modeling_zoedepth import *
 else:
     import sys
src/transformers/models/zoedepth/image_processing_zoedepth_fast.py (new file)
@@ -0,0 +1,328 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ZoeDepth."""

from typing import (
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

import numpy as np

from ...image_processing_utils import (
    BatchFeature,
)
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    get_image_size,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
    requires_backends,
)
from .image_processing_zoedepth import get_resize_output_image_size
from .modeling_zoedepth import ZoeDepthDepthEstimatorOutput


if is_torch_available():
    import torch

if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F

    from torchvision.transforms import InterpolationMode


logger = logging.get_logger(__name__)
class ZoeDepthFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_pad (`bool`, *optional*, defaults to `True`):
        Whether to pad the input image.
    keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
        If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it
        for both dimensions. This ensures that the image is scaled down as little as possible while still fitting
        within the desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a
        size that is a multiple of this value by flooring the height and width to the nearest multiple of this value.
        Can be overridden by `keep_aspect_ratio` in `preprocess`.
    ensure_multiple_of (`int`, *optional*, defaults to 32):
        If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
        the height and width to the nearest multiple of this value.
        Works both with and without `keep_aspect_ratio` being set to `True`.
        Can be overridden by `ensure_multiple_of` in `preprocess`.
    """

    do_pad: Optional[bool]
    keep_aspect_ratio: Optional[bool]
    ensure_multiple_of: Optional[int]
@auto_docstring
class ZoeDepthImageProcessorFast(BaseImageProcessorFast):
    do_pad = True
    do_rescale = True
    do_normalize = True
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_resize = True
    size = {"height": 384, "width": 512}
    resample = PILImageResampling.BILINEAR
    keep_aspect_ratio = True
    ensure_multiple_of = 1 / 32
    valid_kwargs = ZoeDepthFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs]) -> None:
        super().__init__(**kwargs)

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        **kwargs: Unpack[ZoeDepthFastImageProcessorKwargs],
    ) -> BatchFeature:
        return super().preprocess(images, **kwargs)
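The class attributes above give the processor ImageNet-standard normalization and a 384x512 default size out of the box. A minimal usage sketch (illustrative only, assuming this commit is installed; the fast processors accept tensor inputs):

```python
import torch
from transformers import ZoeDepthImageProcessorFast

processor = ZoeDepthImageProcessorFast()  # defaults: do_pad=True, keep_aspect_ratio=True

# Any ZoeDepthFastImageProcessorKwargs field can also be overridden per call.
images = [torch.rand(3, 480, 640)]
batch = processor(images, return_tensors="pt", do_pad=False, keep_aspect_ratio=False)
print(batch.pixel_values.shape)
```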
    def resize(
        self,
        images: "torch.Tensor",
        size: SizeDict,
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        interpolation: Optional["F.InterpolationMode"] = None,
    ) -> "torch.Tensor":
        """
        Resize an image or batched images to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is
        `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. If
        `ensure_multiple_of` is set, the image is resized to a size that is a multiple of this value.

        Args:
            images (`torch.Tensor`):
                Images to resize.
            size (`Dict[str, int]`):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            interpolation (`F.InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to the
                size specified in `size`.
        """
        if not size.height or not size.width:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size}")
        output_size = get_resize_output_image_size(
            images,
            output_size=(size.height, size.width),
            keep_aspect_ratio=keep_aspect_ratio,
            multiple=ensure_multiple_of,
            input_data_format=ChannelDimension.FIRST,
        )
        height, width = output_size

        resized_images = torch.nn.functional.interpolate(
            images, (int(height), int(width)), mode=interpolation.value, align_corners=True
        )

        return resized_images
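`resize` converts the torchvision `InterpolationMode` enum into the string `mode` that `torch.nn.functional.interpolate` expects via `interpolation.value`. A standalone sketch of that mapping (illustrative only, not part of the diff):

```python
import torch
from torchvision.transforms import InterpolationMode

images = torch.rand(2, 3, 480, 640)      # (batch, channels, height, width)
mode = InterpolationMode.BILINEAR.value  # -> "bilinear"
resized = torch.nn.functional.interpolate(images, (384, 512), mode=mode, align_corners=True)
print(resized.shape)                     # torch.Size([2, 3, 384, 512])
```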
    def _pad_images(
        self,
        images: "torch.Tensor",
    ):
        """
        Args:
            images (`torch.Tensor`):
                Images to pad.
        """
        height, width = get_image_size(images, channel_dim=ChannelDimension.FIRST)

        pad_height = int(np.sqrt(height / 2) * 3)
        pad_width = int(np.sqrt(width / 2) * 3)

        return F.pad(images, padding=(pad_width, pad_height), padding_mode="reflect")
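The padding amounts reproduce the original ZoeDepth reflect padding of `int(sqrt(dim / 2) * 3)` per side. Worked numbers for the default 384x512 input (arithmetic illustration only):

```python
import numpy as np

height, width = 384, 512
pad_height = int(np.sqrt(height / 2) * 3)  # int(sqrt(192) * 3) = int(41.56...) = 41
pad_width = int(np.sqrt(width / 2) * 3)    # int(sqrt(256) * 3) = int(48.0) = 48

# F.pad pads both sides, so the padded image is
# (384 + 2 * 41) x (512 + 2 * 48) = 466 x 608 before resizing.
```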
    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        keep_aspect_ratio: Optional[bool],
        ensure_multiple_of: Optional[int],
        interpolation: Optional["F.InterpolationMode"],
        do_pad: bool,
        do_rescale: bool,
        rescale_factor: Optional[float],
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_rescale:
                stacked_images = self.rescale(stacked_images, rescale_factor)
            if do_pad:
                stacked_images = self._pad_images(images=stacked_images)
            if do_resize:
                stacked_images = self.resize(
                    stacked_images, size, keep_aspect_ratio, ensure_multiple_of, interpolation
                )
            if do_normalize:
                stacked_images = self.normalize(stacked_images, image_mean, image_std)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        processed_images = torch.stack(resized_images, dim=0) if return_tensors else resized_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
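Note the order in `_preprocess`: rescale, reflect-pad, resize, then normalize, with inputs grouped by shape so each group is processed as a single batched tensor. A small batched-call sketch (illustrative only, assuming this commit is installed):

```python
import torch
from transformers import ZoeDepthImageProcessorFast

processor = ZoeDepthImageProcessorFast()
images = [torch.rand(3, 480, 640) for _ in range(4)]  # identical shapes -> one group
batch = processor(images, return_tensors="pt")
print(batch.pixel_values.shape)  # stacked to (4, 3, H, W) since return_tensors is set
```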
    def post_process_depth_estimation(
        self,
        outputs: "ZoeDepthDepthEstimatorOutput",
        source_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
        outputs_flipped: Optional[Union["ZoeDepthDepthEstimatorOutput", None]] = None,
        do_remove_padding: Optional[Union[bool, None]] = None,
    ) -> List[Dict[str, TensorType]]:
        """
        Converts the raw output of [`ZoeDepthDepthEstimatorOutput`] into final depth predictions and depth PIL images.
        Only supports PyTorch.

        Args:
            outputs ([`ZoeDepthDepthEstimatorOutput`]):
                Raw outputs of the model.
            source_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the source size
                (height, width) of each image in the batch before preprocessing. This argument should be treated as
                required unless the user passes `do_remove_padding=False` as input to this function.
            target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                (height, width) of each image in the batch. If left to None, predictions will not be resized.
            outputs_flipped ([`ZoeDepthDepthEstimatorOutput`], *optional*):
                Raw outputs of the model from flipped input (averaged out in the end).
            do_remove_padding (`bool`, *optional*):
                By default ZoeDepth adds padding equal to `int(√(height / 2) * 3)` (and similarly for width) to fix
                the boundary artifacts in the output depth map, so we need to remove this padding during
                post-processing. The parameter exists here in case the user changed the image preprocessing to not
                include padding.

        Returns:
            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth

        if (outputs_flipped is not None) and (predicted_depth.shape != outputs_flipped.predicted_depth.shape):
            raise ValueError("Make sure that `outputs` and `outputs_flipped` have the same shape")

        if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
            )

        if do_remove_padding is None:
            do_remove_padding = self.do_pad

        if source_sizes is None and do_remove_padding:
            raise ValueError(
                "Either `source_sizes` should be passed in, or `do_remove_padding` should be set to False"
            )

        if (source_sizes is not None) and (len(predicted_depth) != len(source_sizes)):
            raise ValueError(
                "Make sure that you pass in as many source image sizes as the batch dimension of the logits"
            )

        if outputs_flipped is not None:
            predicted_depth = (predicted_depth + torch.flip(outputs_flipped.predicted_depth, dims=[-1])) / 2

        predicted_depth = predicted_depth.unsqueeze(1)

        # The ZoeDepth model adds padding around the images to fix the boundary artifacts in the output depth map.
        # The padding length is `int(np.sqrt(img_h / 2) * fh)` for the height and similar for the width,
        # where fh (and fw respectively) are equal to '3' by default.
        # See https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L57
        # for the original implementation.
        # In this section, we remove this padding to get the final depth image and depth prediction.
        padding_factor_h = padding_factor_w = 3

        results = []
        target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
        source_sizes = [None] * len(predicted_depth) if source_sizes is None else source_sizes
        for depth, target_size, source_size in zip(predicted_depth, target_sizes, source_sizes):
            # depth.shape = [1, H, W]
            if source_size is not None:
                pad_h = pad_w = 0

                if do_remove_padding:
                    pad_h = int(np.sqrt(source_size[0] / 2) * padding_factor_h)
                    pad_w = int(np.sqrt(source_size[1] / 2) * padding_factor_w)

                depth = F.resize(
                    depth,
                    size=[source_size[0] + 2 * pad_h, source_size[1] + 2 * pad_w],
                    interpolation=InterpolationMode.BICUBIC,
                    antialias=False,
                )

                if pad_h > 0:
                    depth = depth[:, pad_h:-pad_h, :]
                if pad_w > 0:
                    depth = depth[:, :, pad_w:-pad_w]

            if target_size is not None:
                target_size = [target_size[0], target_size[1]]
                depth = F.resize(
                    depth,
                    size=target_size,
                    interpolation=InterpolationMode.BICUBIC,
                    antialias=False,
                )
            depth = depth.squeeze(0)
            # depth.shape = [H, W]
            results.append({"predicted_depth": depth})

        return results


__all__ = ["ZoeDepthImageProcessorFast"]
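End-to-end, the fast processor pairs with `ZoeDepthForDepthEstimation` the same way the slow one does; `source_sizes` is effectively required in post-processing because padding is applied during preprocessing. A hedged sketch (illustrative only; assumes the public `Intel/zoedepth-nyu-kitti` checkpoint and a hypothetical local `example.jpg`):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation

processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", use_fast=True)
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")

image = Image.open("example.jpg")  # hypothetical local file
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_depth_estimation(
    outputs, source_sizes=[(image.height, image.width)]
)
depth = results[0]["predicted_depth"]  # (H, W) tensor with padding removed
```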
tests/models/zoedepth/test_image_processing_zoedepth.py
@@ -14,18 +14,30 @@
 import unittest
+from dataclasses import dataclass
 
 import numpy as np
 
-from transformers.file_utils import is_vision_available
 from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
 
 from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
 
 
+if is_torch_available():
+    import torch
+
 if is_vision_available():
     from transformers import ZoeDepthImageProcessor
 
+if is_torchvision_available():
+    from transformers import ZoeDepthImageProcessorFast
+
+
+@dataclass
+class ZoeDepthDepthOutputProxy:
+    predicted_depth: torch.FloatTensor = None
+
+
 class ZoeDepthImageProcessingTester:
     def __init__(
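The `@dataclass` proxy above lets post-processing be tested without running a model: `post_process_depth_estimation` only reads the `predicted_depth` attribute of its `outputs` argument. Illustrative only:

```python
import torch

# Stands in for a real ZoeDepthDepthEstimatorOutput in the equivalence test below.
proxy = ZoeDepthDepthOutputProxy(predicted_depth=torch.rand(2, 32, 32))
print(proxy.predicted_depth.shape)  # torch.Size([2, 32, 32])
```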
@@ -43,7 +55,7 @@ class ZoeDepthImageProcessingTester:
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
-        do_pad=False,
+        do_pad=True,
     ):
         size = size if size is not None else {"height": 18, "width": 18}
         self.parent = parent
@@ -87,11 +99,25 @@ class ZoeDepthImageProcessingTester:
             torchify=torchify,
         )
 
+    def prepare_depth_outputs(self):
+        depth_tensors = prepare_image_inputs(
+            batch_size=self.batch_size,
+            num_channels=1,
+            min_resolution=self.min_resolution,
+            max_resolution=self.max_resolution,
+            equal_resolution=True,
+            torchify=True,
+        )
+        depth_tensors = [depth_tensor.squeeze(0) for depth_tensor in depth_tensors]
+        stacked_depth_tensors = torch.stack(depth_tensors, dim=0)
+        return ZoeDepthDepthOutputProxy(predicted_depth=stacked_depth_tensors)
+
+
 @require_torch
 @require_vision
 class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
     image_processing_class = ZoeDepthImageProcessor if is_vision_available() else None
+    fast_image_processing_class = ZoeDepthImageProcessorFast if is_torchvision_available() else None
 
     def setUp(self):
         super().setUp()
@@ -115,10 +141,14 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         self.assertTrue(hasattr(image_processing, "do_pad"))
 
     def test_image_processor_from_dict_with_kwargs(self):
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
-        self.assertEqual(image_processor.size, {"height": 18, "width": 18})
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class(**self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"height": 18, "width": 18})
 
-        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
-        self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+        for image_processing_class in self.image_processor_list:
+            modified_dict = self.image_processor_dict
+            modified_dict["size"] = 42
+            image_processor = image_processing_class(**modified_dict)
+            self.assertEqual(image_processor.size, {"height": 42, "width": 42})
 
     def test_ensure_multiple_of(self):
@@ -127,7 +157,8 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
 
         size = {"height": 380, "width": 513}
         multiple = 32
-        image_processor = ZoeDepthImageProcessor(
-            do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
-        )
-        pixel_values = image_processor(image, return_tensors="pt").pixel_values
+        for image_processor_class in self.image_processor_list:
+            image_processor = image_processor_class(
+                do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
+            )
+            pixel_values = image_processor(image, return_tensors="pt").pixel_values
@@ -142,7 +173,8 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         height, width = 512, 512
         size = {"height": height, "width": width}
         multiple = 32
-        image_processor = ZoeDepthImageProcessor(
-            do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
-        )
-        pixel_values = image_processor(image, return_tensors="pt").pixel_values
+        for image_processor_class in self.image_processor_list:
+            image_processor = image_processor_class(
+                do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
+            )
+            pixel_values = image_processor(image, return_tensors="pt").pixel_values
@@ -157,14 +189,18 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
         image = np.zeros((height, width, 3))
 
         size = {"height": 512, "width": 512}
-        image_processor = ZoeDepthImageProcessor(do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1)
-        pixel_values = image_processor(image, return_tensors="pt").pixel_values
+        for image_processor_class in self.image_processor_list:
+            image_processor = image_processor_class(
+                do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1
+            )
+            pixel_values = image_processor(image, return_tensors="pt").pixel_values
 
-        # As can be seen, the image is resized to the maximum size that fits in the specified size
-        self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670])
+            # As can be seen, the image is resized to the maximum size that fits in the specified size
+            self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670])
 
         # Test `keep_aspect_ratio=False` by turning off all other variables which affect the size
-        image_processor = ZoeDepthImageProcessor(
-            do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1
-        )
-        pixel_values = image_processor(image, return_tensors="pt").pixel_values
+        for image_processor_class in self.image_processor_list:
+            image_processor = image_processor_class(
+                do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1
+            )
+            pixel_values = image_processor(image, return_tensors="pt").pixel_values
@@ -177,10 +213,39 @@ class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
 
         size = {"height": 511, "width": 511}
         multiple = 32
-        image_processor = ZoeDepthImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple)
-        pixel_values = image_processor(image, return_tensors="pt").pixel_values
+        for image_processor_class in self.image_processor_list:
+            image_processor = image_processor_class(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple)
 
-        self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
-        self.assertTrue(pixel_values.shape[2] % multiple == 0)
-        self.assertTrue(pixel_values.shape[3] % multiple == 0)
+            pixel_values = image_processor(image, return_tensors="pt").pixel_values
+
+            self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
+            self.assertTrue(pixel_values.shape[2] % multiple == 0)
+            self.assertTrue(pixel_values.shape[3] % multiple == 0)
+
+    # extend this test to check if removal of padding works fine!
+    def test_post_processing_equivalence(self):
+        outputs = self.image_processor_tester.prepare_depth_outputs()
+        image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
+        image_processor_slow = self.image_processing_class(**self.image_processor_dict)
+
+        source_sizes = [outputs.predicted_depth.shape[1:]] * self.image_processor_tester.batch_size
+        target_sizes = [
+            torch.Size([outputs.predicted_depth.shape[1] // 2, *(outputs.predicted_depth.shape[2:])])
+        ] * self.image_processor_tester.batch_size
+
+        processed_fast = image_processor_fast.post_process_depth_estimation(
+            outputs,
+            source_sizes=source_sizes,
+            target_sizes=target_sizes,
+        )
+        processed_slow = image_processor_slow.post_process_depth_estimation(
+            outputs,
+            source_sizes=source_sizes,
+            target_sizes=target_sizes,
+        )
+        for pred_fast, pred_slow in zip(processed_fast, processed_slow):
+            depth_fast = pred_fast["predicted_depth"]
+            depth_slow = pred_slow["predicted_depth"]
+
+            torch.testing.assert_close(depth_fast, depth_slow, atol=1e-1, rtol=1e-3)
+            self.assertLessEqual(torch.mean(torch.abs(depth_fast.float() - depth_slow.float())).item(), 5e-3)