Add post_process_depth_estimation to image processors and support ZoeDepth's inference intricacies (#32550)

* add colorize_depth and matplotlib availability check

* add post_process_depth_estimation for zoedepth + tests

* add post_process_depth_estimation for DPT + tests

* add post_process_depth_estimation in DepthEstimationPipeline & special case for zoedepth

* run `make fixup`

* fix import-related error in tests

* fix more import-related errors in tests

* forgot some `torch` calls in declarations

* remove `torch` call in zoedepth tests that caused error

* updated docs for depth estimation

* small fix for `colorize` input/output types

* remove `colorize_depth`, fix various names, remove matplotlib dependency

* fix formatting

* run fixup

* different images for test

* update examples in `forward` functions

* fixed broken links

* fix output types for docs

* possible format fix inside `<Tip>`

* Readability related updates

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Readability related update

* cleanup after merge

* refactor `post_process_depth_estimation` to return dict; simplify ZoeDepth's `post_process_depth_estimation`

* rewrite dict merging to support python 3.8

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Alexandros Benetatos 2024-10-22 16:50:54 +03:00 committed by GitHub
parent 104599d7a8
commit c31a6ff474
13 changed files with 437 additions and 203 deletions


@ -84,27 +84,24 @@ If you want to do the pre- and postprocessing yourself, here's how to do that:
>>> with torch.no_grad():
... outputs = model(**inputs)
... predicted_depth = outputs.predicted_depth
>>> # interpolate to original size
>>> prediction = torch.nn.functional.interpolate(
... predicted_depth.unsqueeze(1),
... size=image.size[::-1],
... mode="bicubic",
... align_corners=False,
>>> # interpolate to original size and visualize the prediction
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> output = prediction.squeeze().cpu().numpy()
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
>>> depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
>>> depth = depth.detach().cpu().numpy() * 255
>>> depth = Image.fromarray(depth.astype("uint8"))
```
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Depth Anything.
- [Monocular depth estimation task guide](../tasks/depth_estimation)
- [Monocular depth estimation task guide](../tasks/monocular_depth_estimation)
- A notebook showcasing inference with [`DepthAnythingForDepthEstimation`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Depth%20Anything/Predicting_depth_in_an_image_with_Depth_Anything.ipynb). 🌎
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.


@ -78,27 +78,24 @@ If you want to do the pre- and post-processing yourself, here's how to do that:
>>> with torch.no_grad():
... outputs = model(**inputs)
... predicted_depth = outputs.predicted_depth
>>> # interpolate to original size
>>> prediction = torch.nn.functional.interpolate(
... predicted_depth.unsqueeze(1),
... size=image.size[::-1],
... mode="bicubic",
... align_corners=False,
>>> # interpolate to original size and visualize the prediction
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> output = prediction.squeeze().cpu().numpy()
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
>>> depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
>>> depth = depth.detach().cpu().numpy() * 255
>>> depth = Image.fromarray(depth.astype("uint8"))
```
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Depth Anything.
- [Monocular depth estimation task guide](../tasks/depth_estimation)
- [Monocular depth estimation task guide](../tasks/monocular_depth_estimation)
- [Depth Anything V2 demo](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2).
- A notebook showcasing inference with [`DepthAnythingForDepthEstimation`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Depth%20Anything/Predicting_depth_in_an_image_with_Depth_Anything.ipynb). 🌎
- [Core ML conversion of the `small` variant for use on Apple Silicon](https://huggingface.co/apple/coreml-depth-anything-v2-small).


@ -39,54 +39,66 @@ The original code can be found [here](https://github.com/isl-org/ZoeDepth).
The easiest way to perform inference with ZoeDepth is by leveraging the [pipeline API](../main_classes/pipelines.md):
```python
from transformers import pipeline
from PIL import Image
import requests
>>> from transformers import pipeline
>>> from PIL import Image
>>> import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
result = pipe(image)
depth = result["depth"]
>>> pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
>>> result = pipe(image)
>>> depth = result["depth"]
```
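The pipeline result is a dict with the raw depth tensor under `"predicted_depth"` and a visualization-ready PIL image under `"depth"` (see the pipeline changes further down); a minimal sketch of inspecting it, with an illustrative output filename:
```python
>>> # "predicted_depth" is a torch.Tensor resized back to the input image size
>>> predicted_depth = result["predicted_depth"]
>>> # "depth" is a PIL.Image with the depth map min-max normalized to the 0-255 range
>>> result["depth"].save("zoedepth_cats.png")
```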
Alternatively, one can also perform inference using the classes:
```python
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests
>>> from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
>>> import torch
>>> import numpy as np
>>> from PIL import Image
>>> import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
>>> image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
>>> model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")
>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
>>> with torch.no_grad():
... outputs = model(**inputs)
# interpolate to original size
prediction = torch.nn.functional.interpolate(
predicted_depth.unsqueeze(1),
size=image.size[::-1],
mode="bicubic",
align_corners=False,
)
>>> # interpolate to original size and visualize the prediction
>>> ## ZoeDepth dynamically pads the input image. Thus we pass the original image size as an argument
>>> ## to `post_process_depth_estimation` to remove the padding and resize to original dimensions.
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... source_sizes=[(image.height, image.width)],
... )
# visualize the prediction
output = prediction.squeeze().cpu().numpy()
formatted = (output * 255 / np.max(output)).astype("uint8")
depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
>>> depth = depth.detach().cpu().numpy() * 255
>>> depth = Image.fromarray(depth.astype("uint8"))
```
<Tip>
<p>In the <a href="https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L131">original implementation</a>, the ZoeDepth model performs inference on both the original and flipped images and averages the results. The <code>post_process_depth_estimation</code> function handles this when the flipped outputs are passed to its optional <code>outputs_flipped</code> argument:</p>
<pre><code class="language-Python">&gt;&gt;&gt; with torch.no_grad():
... outputs = model(**inputs)
... outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
&gt;&gt;&gt; post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... source_sizes=[(image.height, image.width)],
... outputs_flipped=outputs_flipped,
... )
</code></pre>
</Tip>
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ZoeDepth.


@ -126,97 +126,34 @@ Pass the prepared inputs through the model:
... outputs = model(pixel_values)
```
Let's post-process and visualize the results.
We need to pad and then resize the outputs so that the predicted depth map has the same dimensions as the original image. After resizing, we will remove the padded regions from the depth map.
Let's post-process the results to remove any padding and resize the depth map to match the original image size. `post_process_depth_estimation` returns a list of dicts, each containing the `"predicted_depth"` tensor.
```py
>>> import numpy as np
>>> import torch.nn.functional as F
>>> # ZoeDepth dynamically pads the input image. Thus we pass the original image size as an argument
>>> # to `post_process_depth_estimation` to remove the padding and resize to original dimensions.
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... source_sizes=[(image.height, image.width)],
... )
>>> predicted_depth = outputs.predicted_depth.unsqueeze(dim=1)
>>> height, width = pixel_values.shape[2:]
>>> height_padding_factor = width_padding_factor = 3
>>> pad_h = int(np.sqrt(height/2) * height_padding_factor)
>>> pad_w = int(np.sqrt(width/2) * width_padding_factor)
>>> if predicted_depth.shape[-2:] != pixel_values.shape[-2:]:
>>> predicted_depth = F.interpolate(predicted_depth, size= (height, width), mode='bicubic', align_corners=False)
>>> if pad_h > 0:
predicted_depth = predicted_depth[:, :, pad_h:-pad_h,:]
>>> if pad_w > 0:
predicted_depth = predicted_depth[:, :, :, pad_w:-pad_w]
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
>>> depth = depth.detach().cpu().numpy() * 255
>>> depth = Image.fromarray(depth.astype("uint8"))
```
We can now visualize the results (the function below is taken from the [GaussianObject](https://github.com/GaussianObject/GaussianObject/blob/ad6629efadb57902d5f8bc0fa562258029a4bdf1/pred_monodepth.py#L11) framework).
```py
import matplotlib
def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
"""Converts a depth map to a color image.
Args:
value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
Returns:
numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
"""
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()
value = value.squeeze()
if invalid_mask is None:
invalid_mask = value == invalid_val
mask = np.logical_not(invalid_mask)
# normalize
vmin = np.percentile(value[mask],2) if vmin is None else vmin
vmax = np.percentile(value[mask],85) if vmax is None else vmax
if vmin != vmax:
value = (value - vmin) / (vmax - vmin) # vmin..vmax
else:
# Avoid 0-division
value = value * 0.
# squeeze last dim if it exists
# grey out the invalid values
value[invalid_mask] = np.nan
cmapper = matplotlib.colormaps.get_cmap(cmap)
if value_transform:
value = value_transform(value)
# value = value / value.max()
value = cmapper(value, bytes=True) # (nxmx4)
# img = value[:, :, :]
img = value[...]
img[invalid_mask] = background_color
# return img.transpose((2, 0, 1))
if gamma_corrected:
# gamma correction
img = img / 255
img = np.power(img, 2.2)
img = img * 255
img = img.astype(np.uint8)
return img
>>> result = colorize(predicted_depth.cpu().squeeze().numpy())
>>> Image.fromarray(result)
```
<Tip>
<p>In the <a href="https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L131">original implementation</a>, the ZoeDepth model performs inference on both the original and flipped images and averages the results. The <code>post_process_depth_estimation</code> function handles this when the flipped outputs are passed to its optional <code>outputs_flipped</code> argument:</p>
<pre><code class="language-Python">&gt;&gt;&gt; with torch.no_grad():
... outputs = model(pixel_values)
... outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
&gt;&gt;&gt; post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... source_sizes=[(image.height, image.width)],
... outputs_flipped=outputs_flipped,
... )
</code></pre>
</Tip>
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/depth-visualization-zoe.png" alt="Depth estimation visualization"/>


@ -413,20 +413,18 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
>>> with torch.no_grad():
... outputs = model(**inputs)
... predicted_depth = outputs.predicted_depth
>>> # interpolate to original size
>>> prediction = torch.nn.functional.interpolate(
... predicted_depth.unsqueeze(1),
... size=image.size[::-1],
... mode="bicubic",
... align_corners=False,
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> output = prediction.squeeze().cpu().numpy()
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
>>> depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 255 / predicted_depth.max()
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
```"""
loss = None
if labels is not None:


@ -15,7 +15,11 @@
"""Image processor class for DPT."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
if TYPE_CHECKING:
from ...modeling_outputs import DepthEstimatorOutput
import numpy as np
@ -37,7 +41,13 @@ from ...image_utils import (
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
is_vision_available,
logging,
requires_backends,
)
if is_torch_available():
@ -461,3 +471,44 @@ class DPTImageProcessor(BaseImageProcessor):
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
def post_process_depth_estimation(
self,
outputs: "DepthEstimatorOutput",
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
) -> List[Dict[str, TensorType]]:
"""
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions.
Only supports PyTorch.
Args:
outputs ([`DepthEstimatorOutput`]):
Raw outputs of the model.
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
Returns:
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
predictions.
"""
requires_backends(self, "torch")
predicted_depth = outputs.predicted_depth
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
)
results = []
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
for depth, target_size in zip(predicted_depth, target_sizes):
if target_size is not None:
depth = torch.nn.functional.interpolate(
depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
).squeeze()
results.append({"predicted_depth": depth})
return results
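For reference, a minimal usage sketch of the method added above, outside the pipeline; the `Intel/dpt-large` checkpoint is the one used in the integration test below, everything else is standard Transformers usage:
```python
import requests
import torch
from PIL import Image

from transformers import DPTForDepthEstimation, DPTImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One dict per image; "predicted_depth" is bicubic-interpolated to the requested
# (height, width) when `target_sizes` is given, otherwise left at the model's resolution.
post_processed = image_processor.post_process_depth_estimation(
    outputs, target_sizes=[(image.height, image.width)]
)
predicted_depth = post_processed[0]["predicted_depth"]  # torch.Tensor of shape (image.height, image.width)
```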


@ -1121,20 +1121,18 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
>>> with torch.no_grad():
... outputs = model(**inputs)
... predicted_depth = outputs.predicted_depth
>>> # interpolate to original size
>>> prediction = torch.nn.functional.interpolate(
... predicted_depth.unsqueeze(1),
... size=image.size[::-1],
... mode="bicubic",
... align_corners=False,
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> output = prediction.squeeze().cpu().numpy()
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
>>> depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 255 / predicted_depth.max()
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
```"""
loss = None
if labels is not None:


@ -15,10 +15,14 @@
"""Image processor class for ZoeDepth."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
if TYPE_CHECKING:
from .modeling_zoedepth import ZoeDepthDepthEstimatorOutput
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format
from ...image_utils import (
@ -126,10 +130,10 @@ class ZoeDepthImageProcessor(BaseImageProcessor):
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
multiple of this value by flooring the height and width to the nearest multiple of this value.
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it
for both dimensions. This ensures that the image is scaled down as little as possible while still fitting
within the desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a
size that is a multiple of this value by flooring the height and width to the nearest multiple of this value.
Can be overridden by `keep_aspect_ratio` in `preprocess`.
ensure_multiple_of (`int`, *optional*, defaults to 32):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
@ -331,19 +335,21 @@ class ZoeDepthImageProcessor(BaseImageProcessor):
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized by choosing the smaller of
the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of` is also set,
the image is further resized to a size that is a multiple of this value.
Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized by choosing the
smaller of the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of`
is also set, the image is further resized to a size that is a multiple of this value.
keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
multiple of this value by flooring the height and width to the nearest multiple of this value.
If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width
scaling factors and using it for both dimensions. This ensures that the image is scaled down as little
as possible while still fitting within the desired output size. In case `ensure_multiple_of` is also
set, the image is further resized to a size that is a multiple of this value by flooring the height and
width to the nearest multiple of this value.
ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
the height and width to the nearest multiple of this value.
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by
flooring the height and width to the nearest multiple of this value.
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overridden by `ensure_multiple_of` in `preprocess`.
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overridden by
`ensure_multiple_of` in `preprocess`.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
@ -442,3 +448,111 @@ class ZoeDepthImageProcessor(BaseImageProcessor):
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
def post_process_depth_estimation(
self,
outputs: "ZoeDepthDepthEstimatorOutput",
source_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
outputs_flipped: Optional[Union["ZoeDepthDepthEstimatorOutput", None]] = None,
do_remove_padding: Optional[Union[bool, None]] = None,
) -> List[Dict[str, TensorType]]:
"""
Converts the raw output of [`ZoeDepthDepthEstimatorOutput`] into final depth predictions.
Only supports PyTorch.
Args:
outputs ([`ZoeDepthDepthEstimatorOutput`]):
Raw outputs of the model.
source_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the source size
(height, width) of each image in the batch before preprocessing. This argument should be treated as
required unless the user passes `do_remove_padding=False` to this function.
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
outputs_flipped ([`ZoeDepthDepthEstimatorOutput`], *optional*):
Raw outputs of the model from flipped input (averaged out in the end).
do_remove_padding (`bool`, *optional*):
By default ZoeDepth adds padding equal to `int(sqrt(height / 2) * 3)` (and similarly for the width) to fix the
boundary artifacts in the output depth map, so we need to remove this padding during post-processing. The
parameter exists here in case the user changed the image preprocessing to not include padding.
Returns:
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
predictions.
"""
requires_backends(self, "torch")
predicted_depth = outputs.predicted_depth
if (outputs_flipped is not None) and (predicted_depth.shape != outputs_flipped.predicted_depth.shape):
raise ValueError("Make sure that `outputs` and `outputs_flipped` have the same shape")
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
)
if do_remove_padding is None:
do_remove_padding = self.do_pad
if source_sizes is None and do_remove_padding:
raise ValueError(
"Either `source_sizes` should be passed in, or `do_remove_padding` should be set to False"
)
if (source_sizes is not None) and (len(predicted_depth) != len(source_sizes)):
raise ValueError(
"Make sure that you pass in as many source image sizes as the batch dimension of the logits"
)
if outputs_flipped is not None:
predicted_depth = (predicted_depth + torch.flip(outputs_flipped.predicted_depth, dims=[-1])) / 2
predicted_depth = predicted_depth.unsqueeze(1)
# Zoe Depth model adds padding around the images to fix the boundary artifacts in the output depth map
# The padding length is `int(np.sqrt(img_h/2) * fh)` for the height and similar for the width
# fh (and fw respectively) are equal to '3' by default
# Check [here](https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L57)
# for the original implementation.
# In this section, we remove this padding to get the final depth image and depth prediction
padding_factor_h = padding_factor_w = 3
results = []
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
source_sizes = [None] * len(predicted_depth) if source_sizes is None else source_sizes
for depth, target_size, source_size in zip(predicted_depth, target_sizes, source_sizes):
# depth.shape = [1, H, W]
if source_size is not None:
pad_h = pad_w = 0
if do_remove_padding:
pad_h = int(np.sqrt(source_size[0] / 2) * padding_factor_h)
pad_w = int(np.sqrt(source_size[1] / 2) * padding_factor_w)
depth = nn.functional.interpolate(
depth.unsqueeze(1),
size=[source_size[0] + 2 * pad_h, source_size[1] + 2 * pad_w],
mode="bicubic",
align_corners=False,
)
if pad_h > 0:
depth = depth[:, :, pad_h:-pad_h, :]
if pad_w > 0:
depth = depth[:, :, :, pad_w:-pad_w]
depth = depth.squeeze(1)
# depth.shape = [1, H, W]
if target_size is not None:
target_size = [target_size[0], target_size[1]]
depth = nn.functional.interpolate(
depth.unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
)
depth = depth.squeeze()
# depth.shape = [H, W]
results.append({"predicted_depth": depth})
return results
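A hedged usage sketch combining the options above (flipped-pass averaging, padding removal via `source_sizes`, and an explicit `target_sizes`); the doubled target size mirrors `check_target_size` in the tests below:
```python
import requests
import torch
from PIL import Image

from transformers import ZoeDepthForDepthEstimation, ZoeDepthImageProcessor

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")

inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
    # Optional flip augmentation: the two predictions are averaged during post-processing.
    outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))

post_processed = image_processor.post_process_depth_estimation(
    outputs,
    source_sizes=[(image.height, image.width)],           # needed to strip ZoeDepth's dynamic padding
    outputs_flipped=outputs_flipped,
    target_sizes=[(2 * image.height, 2 * image.width)],   # optional extra resize
)
depth = post_processed[0]["predicted_depth"]  # shape: (2 * image.height, 2 * image.width)
```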


@ -1338,20 +1338,18 @@ class ZoeDepthForDepthEstimation(ZoeDepthPreTrainedModel):
>>> with torch.no_grad():
... outputs = model(**inputs)
... predicted_depth = outputs.predicted_depth
>>> # interpolate to original size
>>> prediction = torch.nn.functional.interpolate(
... predicted_depth.unsqueeze(1),
... size=image.size[::-1],
... mode="bicubic",
... align_corners=False,
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... source_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> output = prediction.squeeze().cpu().numpy()
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
>>> depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 255 / predicted_depth.max()
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
```"""
loss = None
if labels is not None:


@ -1,9 +1,13 @@
import warnings
from typing import List, Union
import numpy as np
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from ..utils import (
add_end_docstrings,
is_torch_available,
is_vision_available,
logging,
requires_backends,
)
from .base import Pipeline, build_pipeline_init_args
@ -13,8 +17,6 @@ if is_vision_available():
from ..image_utils import load_image
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
logger = logging.get_logger(__name__)
@ -114,14 +116,19 @@ class DepthEstimationPipeline(Pipeline):
return model_outputs
def postprocess(self, model_outputs):
predicted_depth = model_outputs.predicted_depth
prediction = torch.nn.functional.interpolate(
predicted_depth.unsqueeze(1), size=model_outputs["target_size"], mode="bicubic", align_corners=False
outputs = self.image_processor.post_process_depth_estimation(
model_outputs,
# this acts as `source_sizes` for ZoeDepth and as `target_sizes` for the rest of the models so do *not*
# replace with `target_sizes = [model_outputs["target_size"]]`
[model_outputs["target_size"]],
)
output = prediction.squeeze().cpu().numpy()
formatted = (output * 255 / np.max(output)).astype("uint8")
depth = Image.fromarray(formatted)
output_dict = {}
output_dict["predicted_depth"] = predicted_depth
output_dict["depth"] = depth
return output_dict
formatted_outputs = []
for output in outputs:
depth = output["predicted_depth"].detach().cpu().numpy()
depth = (depth - depth.min()) / (depth.max() - depth.min())
depth = Image.fromarray((depth * 255).astype("uint8"))
formatted_outputs.append({"predicted_depth": output["predicted_depth"], "depth": depth})
return formatted_outputs[0] if len(outputs) == 1 else formatted_outputs
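For context, a minimal sketch of the refactored pipeline's return value from the caller's side; a single image yields one dict (matching the previous behavior), while a list of images yields a list of such dicts:
```python
from transformers import pipeline

pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
result = pipe("http://images.cocodataset.org/val2017/000000039769.jpg")

print(type(result))                     # <class 'dict'>
print(result["predicted_depth"].shape)  # matches the original image (height, width)
print(result["depth"].size)             # PIL size, i.e. (width, height)
```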


@ -384,3 +384,29 @@ class DPTModelIntegrationTest(unittest.TestCase):
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
expected_shape = torch.Size((480, 480))
self.assertEqual(segmentation[0].shape, expected_shape)
def test_post_processing_depth_estimation(self):
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt")
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = image_processor.post_process_depth_estimation(outputs=outputs)[0]["predicted_depth"]
expected_shape = torch.Size((384, 384))
self.assertTrue(predicted_depth.shape == expected_shape)
predicted_depth_l = image_processor.post_process_depth_estimation(outputs=outputs, target_sizes=[(500, 500)])
predicted_depth_l = predicted_depth_l[0]["predicted_depth"]
expected_shape = torch.Size((500, 500))
self.assertTrue(predicted_depth_l.shape == expected_shape)
output_enlarged = torch.nn.functional.interpolate(
predicted_depth.unsqueeze(0).unsqueeze(1), size=(500, 500), mode="bicubic", align_corners=False
).squeeze()
self.assertTrue(output_enlarged.shape == expected_shape)
self.assertTrue(torch.allclose(predicted_depth_l, output_enlarged, rtol=1e-3))


@ -16,6 +16,8 @@
import unittest
import numpy as np
from transformers import Dinov2Config, ZoeDepthConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
@ -212,6 +214,25 @@ def prepare_img():
@require_vision
@slow
class ZoeDepthModelIntegrationTest(unittest.TestCase):
expected_slice_post_processing = {
(False, False): [
[[1.1348238, 1.1193453, 1.130562], [1.1754476, 1.1613507, 1.1701596], [1.2287744, 1.2101802, 1.2148322]],
[[2.7170, 2.6550, 2.6839], [2.9827, 2.9438, 2.9587], [3.2340, 3.1817, 3.1602]],
],
(False, True): [
[[1.0610938, 1.1042216, 1.1429265], [1.1099341, 1.148696, 1.1817775], [1.1656011, 1.1988826, 1.2268101]],
[[2.5848, 2.7391, 2.8694], [2.7882, 2.9872, 3.1244], [2.9436, 3.1812, 3.3188]],
],
(True, False): [
[[1.8382794, 1.8380532, 1.8375976], [1.848761, 1.8485023, 1.8479986], [1.8571457, 1.8568444, 1.8562847]],
[[6.2030, 6.1902, 6.1777], [6.2303, 6.2176, 6.2053], [6.2561, 6.2436, 6.2312]],
],
(True, True): [
[[1.8306141, 1.8305621, 1.8303483], [1.8410318, 1.8409299, 1.8406585], [1.8492792, 1.8491366, 1.8488203]],
[[6.2616, 6.2520, 6.2435], [6.2845, 6.2751, 6.2667], [6.3065, 6.2972, 6.2887]],
],
} # (pad, flip)
def test_inference_depth_estimation(self):
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu").to(torch_device)
@ -255,3 +276,81 @@ class ZoeDepthModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
def check_target_size(
self,
image_processor,
pad_input,
images,
outputs,
raw_outputs,
raw_outputs_flipped=None,
):
outputs_large = image_processor.post_process_depth_estimation(
raw_outputs,
[img.size[::-1] for img in images],
outputs_flipped=raw_outputs_flipped,
target_sizes=[tuple(np.array(img.size[::-1]) * 2) for img in images],
do_remove_padding=pad_input,
)
for img, out, out_l in zip(images, outputs, outputs_large):
out = out["predicted_depth"]
out_l = out_l["predicted_depth"]
out_l_reduced = torch.nn.functional.interpolate(
out_l.unsqueeze(0).unsqueeze(1), size=img.size[::-1], mode="bicubic", align_corners=False
)
self.assertTrue((np.array(out_l.shape)[::-1] == np.array(img.size) * 2).all())
self.assertTrue(torch.allclose(out, out_l_reduced, rtol=2e-2))
def check_post_processing_test(self, image_processor, images, model, pad_input=True, flip_aug=True):
inputs = image_processor(images=images, return_tensors="pt", do_pad=pad_input).to(torch_device)
with torch.no_grad():
raw_outputs = model(**inputs)
raw_outputs_flipped = None
if flip_aug:
raw_outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
outputs = image_processor.post_process_depth_estimation(
raw_outputs,
[img.size[::-1] for img in images],
outputs_flipped=raw_outputs_flipped,
do_remove_padding=pad_input,
)
expected_slices = torch.tensor(self.expected_slice_post_processing[pad_input, flip_aug]).to(torch_device)
for img, out, expected_slice in zip(images, outputs, expected_slices):
out = out["predicted_depth"]
self.assertTrue(img.size == out.shape[::-1])
self.assertTrue(torch.allclose(expected_slice, out[:3, :3], rtol=1e-3))
self.check_target_size(image_processor, pad_input, images, outputs, raw_outputs, raw_outputs_flipped)
def test_inference_depth_estimation_post_processing_nopad_noflip(self):
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
self.check_post_processing_test(image_processor, images, model, pad_input=False, flip_aug=False)
def test_inference_depth_estimation_post_processing_nopad_flip(self):
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
self.check_post_processing_test(image_processor, images, model, pad_input=False, flip_aug=True)
def test_inference_depth_estimation_post_processing_pad_noflip(self):
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
self.check_post_processing_test(image_processor, images, model, pad_input=True, flip_aug=False)
def test_inference_depth_estimation_post_processing_pad_flip(self):
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
self.check_post_processing_test(image_processor, images, model, pad_input=True, flip_aug=True)


@ -129,7 +129,7 @@ class DepthEstimationPipelineTests(unittest.TestCase):
# This seems flaky.
# self.assertEqual(outputs["depth"], "1a39394e282e9f3b0741a90b9f108977")
self.assertEqual(nested_simplify(outputs["predicted_depth"].max().item()), 29.304)
self.assertEqual(nested_simplify(outputs["predicted_depth"].max().item()), 29.306)
self.assertEqual(nested_simplify(outputs["predicted_depth"].min().item()), 2.662)
@require_torch