mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Add post_process_depth_estimation to image processors and support ZoeDepth's inference intricacies (#32550)
* add colorize_depth and matplotlib availability check * add post_process_depth_estimation for zoedepth + tests * add post_process_depth_estimation for DPT + tests * add post_process_depth_estimation in DepthEstimationPipeline & special case for zoedepth * run `make fixup` * fix import related error on tests * fix more import related errors on test * forgot some `torch` calls in declerations * remove `torch` call in zoedepth tests that caused error * updated docs for depth estimation * small fix for `colorize` input/output types * remove `colorize_depth`, fix various names, remove matplotlib dependency * fix formatting * run fixup * different images for test * update examples in `forward` functions * fixed broken links * fix output types for docs * possible format fix inside `<Tip>` * Readability related updates Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Readability related update * cleanup after merge * refactor `post_process_depth_estimation` to return dict; simplify ZoeDepth's `post_process_depth_estimation` * rewrite dict merging to support python 3.8 --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
parent
104599d7a8
commit
c31a6ff474
@ -84,27 +84,24 @@ If you want to do the pre- and postprocessing yourself, here's how to do that:
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... predicted_depth = outputs.predicted_depth
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> prediction = torch.nn.functional.interpolate(
|
||||
... predicted_depth.unsqueeze(1),
|
||||
... size=image.size[::-1],
|
||||
... mode="bicubic",
|
||||
... align_corners=False,
|
||||
>>> # interpolate to original size and visualize the prediction
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> output = prediction.squeeze().cpu().numpy()
|
||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
>>> depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
>>> depth = depth.detach().cpu().numpy() * 255
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Depth Anything.
|
||||
|
||||
- [Monocular depth estimation task guide](../tasks/depth_estimation)
|
||||
- [Monocular depth estimation task guide](../tasks/monocular_depth_estimation)
|
||||
- A notebook showcasing inference with [`DepthAnythingForDepthEstimation`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Depth%20Anything/Predicting_depth_in_an_image_with_Depth_Anything.ipynb). 🌎
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
@ -78,27 +78,24 @@ If you want to do the pre- and post-processing yourself, here's how to do that:
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... predicted_depth = outputs.predicted_depth
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> prediction = torch.nn.functional.interpolate(
|
||||
... predicted_depth.unsqueeze(1),
|
||||
... size=image.size[::-1],
|
||||
... mode="bicubic",
|
||||
... align_corners=False,
|
||||
>>> # interpolate to original size and visualize the prediction
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> output = prediction.squeeze().cpu().numpy()
|
||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
>>> depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
>>> depth = depth.detach().cpu().numpy() * 255
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Depth Anything.
|
||||
|
||||
- [Monocular depth estimation task guide](../tasks/depth_estimation)
|
||||
- [Monocular depth estimation task guide](../tasks/monocular_depth_estimation)
|
||||
- [Depth Anything V2 demo](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2).
|
||||
- A notebook showcasing inference with [`DepthAnythingForDepthEstimation`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Depth%20Anything/Predicting_depth_in_an_image_with_Depth_Anything.ipynb). 🌎
|
||||
- [Core ML conversion of the `small` variant for use on Apple Silicon](https://huggingface.co/apple/coreml-depth-anything-v2-small).
|
||||
|
@ -39,54 +39,66 @@ The original code can be found [here](https://github.com/isl-org/ZoeDepth).
|
||||
The easiest to perform inference with ZoeDepth is by leveraging the [pipeline API](../main_classes/pipelines.md):
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
from PIL import Image
|
||||
import requests
|
||||
>>> from transformers import pipeline
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
|
||||
result = pipe(image)
|
||||
depth = result["depth"]
|
||||
>>> pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
|
||||
>>> result = pipe(image)
|
||||
>>> depth = result["depth"]
|
||||
```
|
||||
|
||||
Alternatively, one can also perform inference using the classes:
|
||||
|
||||
```python
|
||||
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import requests
|
||||
>>> from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
|
||||
>>> import torch
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
|
||||
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
|
||||
>>> model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
|
||||
|
||||
# prepare image for the model
|
||||
inputs = image_processor(images=image, return_tensors="pt")
|
||||
>>> # prepare image for the model
|
||||
>>> inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predicted_depth = outputs.predicted_depth
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(pixel_values)
|
||||
|
||||
# interpolate to original size
|
||||
prediction = torch.nn.functional.interpolate(
|
||||
predicted_depth.unsqueeze(1),
|
||||
size=image.size[::-1],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
>>> # interpolate to original size and visualize the prediction
|
||||
>>> ## ZoeDepth dynamically pads the input image. Thus we pass the original image size as argument
|
||||
>>> ## to `post_process_depth_estimation` to remove the padding and resize to original dimensions.
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... source_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
# visualize the prediction
|
||||
output = prediction.squeeze().cpu().numpy()
|
||||
formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
>>> depth = depth.detach().cpu().numpy() * 255
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```
|
||||
|
||||
<Tip>
|
||||
<p>In the <a href="https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L131">original implementation</a> ZoeDepth model performs inference on both the original and flipped images and averages out the results. The <code>post_process_depth_estimation</code> function can handle this for us by passing the flipped outputs to the optional <code>outputs_flipped</code> argument:</p>
|
||||
<pre><code class="language-Python">>>> with torch.no_grad():
|
||||
... outputs = model(pixel_values)
|
||||
... outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... source_sizes=[(image.height, image.width)],
|
||||
... outputs_flipped=outputs_flipped,
|
||||
... )
|
||||
</code></pre>
|
||||
</Tip>
|
||||
|
||||
## Resources
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ZoeDepth.
|
||||
|
@ -126,97 +126,34 @@ Pass the prepared inputs through the model:
|
||||
... outputs = model(pixel_values)
|
||||
```
|
||||
|
||||
Let's post-process and visualize the results.
|
||||
|
||||
We need to pad and then resize the outputs so that predicted depth map has the same dimension as the original image. After resizing we will remove the padded regions from the depth.
|
||||
Let's post-process the results to remove any padding and resize the depth map to match the original image size. The `post_process_depth_estimation` outputs a list of dicts containing the `"predicted_depth"`.
|
||||
|
||||
```py
|
||||
>>> import numpy as np
|
||||
>>> import torch.nn.functional as F
|
||||
>>> # ZoeDepth dynamically pads the input image. Thus we pass the original image size as argument
|
||||
>>> # to `post_process_depth_estimation` to remove the padding and resize to original dimensions.
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... source_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> predicted_depth = outputs.predicted_depth.unsqueeze(dim=1)
|
||||
>>> height, width = pixel_values.shape[2:]
|
||||
|
||||
>>> height_padding_factor = width_padding_factor = 3
|
||||
>>> pad_h = int(np.sqrt(height/2) * height_padding_factor)
|
||||
>>> pad_w = int(np.sqrt(width/2) * width_padding_factor)
|
||||
|
||||
>>> if predicted_depth.shape[-2:] != pixel_values.shape[-2:]:
|
||||
>>> predicted_depth = F.interpolate(predicted_depth, size= (height, width), mode='bicubic', align_corners=False)
|
||||
|
||||
>>> if pad_h > 0:
|
||||
predicted_depth = predicted_depth[:, :, pad_h:-pad_h,:]
|
||||
>>> if pad_w > 0:
|
||||
predicted_depth = predicted_depth[:, :, :, pad_w:-pad_w]
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min())
|
||||
>>> depth = depth.detach().cpu().numpy() * 255
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```
|
||||
|
||||
We can now visualize the results (the function below is taken from the [GaussianObject](https://github.com/GaussianObject/GaussianObject/blob/ad6629efadb57902d5f8bc0fa562258029a4bdf1/pred_monodepth.py#L11) framework).
|
||||
|
||||
```py
|
||||
import matplotlib
|
||||
|
||||
def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
|
||||
"""Converts a depth map to a color image.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
|
||||
vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
|
||||
vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
|
||||
cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
|
||||
invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
|
||||
invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
|
||||
background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
|
||||
gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
|
||||
value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
|
||||
"""
|
||||
if isinstance(value, torch.Tensor):
|
||||
value = value.detach().cpu().numpy()
|
||||
|
||||
value = value.squeeze()
|
||||
if invalid_mask is None:
|
||||
invalid_mask = value == invalid_val
|
||||
mask = np.logical_not(invalid_mask)
|
||||
|
||||
# normalize
|
||||
vmin = np.percentile(value[mask],2) if vmin is None else vmin
|
||||
vmax = np.percentile(value[mask],85) if vmax is None else vmax
|
||||
if vmin != vmax:
|
||||
value = (value - vmin) / (vmax - vmin) # vmin..vmax
|
||||
else:
|
||||
# Avoid 0-division
|
||||
value = value * 0.
|
||||
|
||||
# squeeze last dim if it exists
|
||||
# grey out the invalid values
|
||||
|
||||
value[invalid_mask] = np.nan
|
||||
cmapper = matplotlib.colormaps.get_cmap(cmap)
|
||||
if value_transform:
|
||||
value = value_transform(value)
|
||||
# value = value / value.max()
|
||||
value = cmapper(value, bytes=True) # (nxmx4)
|
||||
|
||||
# img = value[:, :, :]
|
||||
img = value[...]
|
||||
img[invalid_mask] = background_color
|
||||
|
||||
# return img.transpose((2, 0, 1))
|
||||
if gamma_corrected:
|
||||
# gamma correction
|
||||
img = img / 255
|
||||
img = np.power(img, 2.2)
|
||||
img = img * 255
|
||||
img = img.astype(np.uint8)
|
||||
return img
|
||||
|
||||
>>> result = colorize(predicted_depth.cpu().squeeze().numpy())
|
||||
>>> Image.fromarray(result)
|
||||
```
|
||||
|
||||
|
||||
<Tip>
|
||||
<p>In the <a href="https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L131">original implementation</a> ZoeDepth model performs inference on both the original and flipped images and averages out the results. The <code>post_process_depth_estimation</code> function can handle this for us by passing the flipped outputs to the optional <code>outputs_flipped</code> argument:</p>
|
||||
<pre><code class="language-Python">>>> with torch.no_grad():
|
||||
... outputs = model(pixel_values)
|
||||
... outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... source_sizes=[(image.height, image.width)],
|
||||
... outputs_flipped=outputs_flipped,
|
||||
... )
|
||||
</code></pre>
|
||||
</Tip>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/depth-visualization-zoe.png" alt="Depth estimation visualization"/>
|
||||
|
@ -413,20 +413,18 @@ class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... predicted_depth = outputs.predicted_depth
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> prediction = torch.nn.functional.interpolate(
|
||||
... predicted_depth.unsqueeze(1),
|
||||
... size=image.size[::-1],
|
||||
... mode="bicubic",
|
||||
... align_corners=False,
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> output = prediction.squeeze().cpu().numpy()
|
||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
>>> depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
||||
>>> depth = depth.detach().cpu().numpy()
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```"""
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -15,7 +15,11 @@
|
||||
"""Image processor class for DPT."""
|
||||
|
||||
import math
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...modeling_outputs import DepthEstimatorOutput
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -37,7 +41,13 @@ from ...image_utils import (
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
filter_out_non_signature_kwargs,
|
||||
is_vision_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -461,3 +471,44 @@ class DPTImageProcessor(BaseImageProcessor):
|
||||
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
|
||||
|
||||
return semantic_segmentation
|
||||
|
||||
def post_process_depth_estimation(
|
||||
self,
|
||||
outputs: "DepthEstimatorOutput",
|
||||
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
|
||||
) -> List[Dict[str, TensorType]]:
|
||||
"""
|
||||
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
|
||||
Only supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`DepthEstimatorOutput`]):
|
||||
Raw outputs of the model.
|
||||
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||
|
||||
Returns:
|
||||
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
|
||||
predictions.
|
||||
"""
|
||||
requires_backends(self, "torch")
|
||||
|
||||
predicted_depth = outputs.predicted_depth
|
||||
|
||||
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
|
||||
)
|
||||
|
||||
results = []
|
||||
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
|
||||
for depth, target_size in zip(predicted_depth, target_sizes):
|
||||
if target_size is not None:
|
||||
depth = torch.nn.functional.interpolate(
|
||||
depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
|
||||
).squeeze()
|
||||
|
||||
results.append({"predicted_depth": depth})
|
||||
|
||||
return results
|
||||
|
@ -1121,20 +1121,18 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... predicted_depth = outputs.predicted_depth
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> prediction = torch.nn.functional.interpolate(
|
||||
... predicted_depth.unsqueeze(1),
|
||||
... size=image.size[::-1],
|
||||
... mode="bicubic",
|
||||
... align_corners=False,
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... target_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> output = prediction.squeeze().cpu().numpy()
|
||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
>>> depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
||||
>>> depth = depth.detach().cpu().numpy()
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```"""
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -15,10 +15,14 @@
|
||||
"""Image processor class for ZoeDepth."""
|
||||
|
||||
import math
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_zoedepth import ZoeDepthDepthEstimatorOutput
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
@ -126,10 +130,10 @@ class ZoeDepthImageProcessor(BaseImageProcessor):
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
|
||||
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
|
||||
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
|
||||
multiple of this value by flooring the height and width to the nearest multiple of this value.
|
||||
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it
|
||||
for both dimensions. This ensures that the image is scaled down as little as possible while still fitting
|
||||
within the desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a
|
||||
size that is a multiple of this value by flooring the height and width to the nearest multiple of this value.
|
||||
Can be overidden by `keep_aspect_ratio` in `preprocess`.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to 32):
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
|
||||
@ -331,19 +335,21 @@ class ZoeDepthImageProcessor(BaseImageProcessor):
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. If `keep_aspect_ratio` is `True`, he image is resized by choosing the smaller of
|
||||
the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of` is also set,
|
||||
the image is further resized to a size that is a multiple of this value.
|
||||
Size of the image after resizing. If `keep_aspect_ratio` is `True`, he image is resized by choosing the
|
||||
smaller of the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of`
|
||||
is also set, the image is further resized to a size that is a multiple of this value.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
|
||||
If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
|
||||
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
|
||||
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
|
||||
multiple of this value by flooring the height and width to the nearest multiple of this value.
|
||||
If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width
|
||||
scaling factors and using it for both dimensions. This ensures that the image is scaled down as little
|
||||
as possible while still fitting within the desired output size. In case `ensure_multiple_of` is also
|
||||
set, the image is further resized to a size that is a multiple of this value by flooring the height and
|
||||
width to the nearest multiple of this value.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
|
||||
the height and width to the nearest multiple of this value.
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by
|
||||
flooring the height and width to the nearest multiple of this value.
|
||||
|
||||
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overidden by `ensure_multiple_of` in `preprocess`.
|
||||
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overidden by
|
||||
`ensure_multiple_of` in `preprocess`.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
@ -442,3 +448,111 @@ class ZoeDepthImageProcessor(BaseImageProcessor):
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
def post_process_depth_estimation(
|
||||
self,
|
||||
outputs: "ZoeDepthDepthEstimatorOutput",
|
||||
source_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
|
||||
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
|
||||
outputs_flipped: Optional[Union["ZoeDepthDepthEstimatorOutput", None]] = None,
|
||||
do_remove_padding: Optional[Union[bool, None]] = None,
|
||||
) -> List[Dict[str, TensorType]]:
|
||||
"""
|
||||
Converts the raw output of [`ZoeDepthDepthEstimatorOutput`] into final depth predictions and depth PIL images.
|
||||
Only supports PyTorch.
|
||||
|
||||
Args:
|
||||
outputs ([`ZoeDepthDepthEstimatorOutput`]):
|
||||
Raw outputs of the model.
|
||||
source_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the source size
|
||||
(height, width) of each image in the batch before preprocessing. This argument should be dealt as
|
||||
"required" unless the user passes `do_remove_padding=False` as input to this function.
|
||||
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
|
||||
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
|
||||
(height, width) of each image in the batch. If left to None, predictions will not be resized.
|
||||
outputs_flipped ([`ZoeDepthDepthEstimatorOutput`], *optional*):
|
||||
Raw outputs of the model from flipped input (averaged out in the end).
|
||||
do_remove_padding (`bool`, *optional*):
|
||||
By default ZoeDepth addes padding equal to `int(√(height / 2) * 3)` (and similarly for width) to fix the
|
||||
boundary artifacts in the output depth map, so we need remove this padding during post_processing. The
|
||||
parameter exists here in case the user changed the image preprocessing to not include padding.
|
||||
|
||||
Returns:
|
||||
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
|
||||
predictions.
|
||||
"""
|
||||
requires_backends(self, "torch")
|
||||
|
||||
predicted_depth = outputs.predicted_depth
|
||||
|
||||
if (outputs_flipped is not None) and (predicted_depth.shape != outputs_flipped.predicted_depth.shape):
|
||||
raise ValueError("Make sure that `outputs` and `outputs_flipped` have the same shape")
|
||||
|
||||
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
|
||||
)
|
||||
|
||||
if do_remove_padding is None:
|
||||
do_remove_padding = self.do_pad
|
||||
|
||||
if source_sizes is None and do_remove_padding:
|
||||
raise ValueError(
|
||||
"Either `source_sizes` should be passed in, or `do_remove_padding` should be set to False"
|
||||
)
|
||||
|
||||
if (source_sizes is not None) and (len(predicted_depth) != len(source_sizes)):
|
||||
raise ValueError(
|
||||
"Make sure that you pass in as many source image sizes as the batch dimension of the logits"
|
||||
)
|
||||
|
||||
if outputs_flipped is not None:
|
||||
predicted_depth = (predicted_depth + torch.flip(outputs_flipped.predicted_depth, dims=[-1])) / 2
|
||||
|
||||
predicted_depth = predicted_depth.unsqueeze(1)
|
||||
|
||||
# Zoe Depth model adds padding around the images to fix the boundary artifacts in the output depth map
|
||||
# The padding length is `int(np.sqrt(img_h/2) * fh)` for the height and similar for the width
|
||||
# fh (and fw respectively) are equal to '3' by default
|
||||
# Check [here](https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L57)
|
||||
# for the original implementation.
|
||||
# In this section, we remove this padding to get the final depth image and depth prediction
|
||||
padding_factor_h = padding_factor_w = 3
|
||||
|
||||
results = []
|
||||
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
|
||||
source_sizes = [None] * len(predicted_depth) if source_sizes is None else source_sizes
|
||||
for depth, target_size, source_size in zip(predicted_depth, target_sizes, source_sizes):
|
||||
# depth.shape = [1, H, W]
|
||||
if source_size is not None:
|
||||
pad_h = pad_w = 0
|
||||
|
||||
if do_remove_padding:
|
||||
pad_h = int(np.sqrt(source_size[0] / 2) * padding_factor_h)
|
||||
pad_w = int(np.sqrt(source_size[1] / 2) * padding_factor_w)
|
||||
|
||||
depth = nn.functional.interpolate(
|
||||
depth.unsqueeze(1),
|
||||
size=[source_size[0] + 2 * pad_h, source_size[1] + 2 * pad_w],
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
if pad_h > 0:
|
||||
depth = depth[:, :, pad_h:-pad_h, :]
|
||||
if pad_w > 0:
|
||||
depth = depth[:, :, :, pad_w:-pad_w]
|
||||
|
||||
depth = depth.squeeze(1)
|
||||
# depth.shape = [1, H, W]
|
||||
if target_size is not None:
|
||||
target_size = [target_size[0], target_size[1]]
|
||||
depth = nn.functional.interpolate(
|
||||
depth.unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
|
||||
)
|
||||
depth = depth.squeeze()
|
||||
# depth.shape = [H, W]
|
||||
results.append({"predicted_depth": depth})
|
||||
|
||||
return results
|
||||
|
@ -1338,20 +1338,18 @@ class ZoeDepthForDepthEstimation(ZoeDepthPreTrainedModel):
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
... predicted_depth = outputs.predicted_depth
|
||||
|
||||
>>> # interpolate to original size
|
||||
>>> prediction = torch.nn.functional.interpolate(
|
||||
... predicted_depth.unsqueeze(1),
|
||||
... size=image.size[::-1],
|
||||
... mode="bicubic",
|
||||
... align_corners=False,
|
||||
>>> post_processed_output = image_processor.post_process_depth_estimation(
|
||||
... outputs,
|
||||
... source_sizes=[(image.height, image.width)],
|
||||
... )
|
||||
|
||||
>>> # visualize the prediction
|
||||
>>> output = prediction.squeeze().cpu().numpy()
|
||||
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
>>> depth = Image.fromarray(formatted)
|
||||
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
|
||||
>>> depth = predicted_depth * 255 / predicted_depth.max()
|
||||
>>> depth = depth.detach().cpu().numpy()
|
||||
>>> depth = Image.fromarray(depth.astype("uint8"))
|
||||
```"""
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
@ -1,9 +1,13 @@
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import Pipeline, build_pipeline_init_args
|
||||
|
||||
|
||||
@ -13,8 +17,6 @@ if is_vision_available():
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -114,14 +116,19 @@ class DepthEstimationPipeline(Pipeline):
|
||||
return model_outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
predicted_depth = model_outputs.predicted_depth
|
||||
prediction = torch.nn.functional.interpolate(
|
||||
predicted_depth.unsqueeze(1), size=model_outputs["target_size"], mode="bicubic", align_corners=False
|
||||
outputs = self.image_processor.post_process_depth_estimation(
|
||||
model_outputs,
|
||||
# this acts as `source_sizes` for ZoeDepth and as `target_sizes` for the rest of the models so do *not*
|
||||
# replace with `target_sizes = [model_outputs["target_size"]]`
|
||||
[model_outputs["target_size"]],
|
||||
)
|
||||
output = prediction.squeeze().cpu().numpy()
|
||||
formatted = (output * 255 / np.max(output)).astype("uint8")
|
||||
depth = Image.fromarray(formatted)
|
||||
output_dict = {}
|
||||
output_dict["predicted_depth"] = predicted_depth
|
||||
output_dict["depth"] = depth
|
||||
return output_dict
|
||||
|
||||
formatted_outputs = []
|
||||
for output in outputs:
|
||||
depth = output["predicted_depth"].detach().cpu().numpy()
|
||||
depth = (depth - depth.min()) / (depth.max() - depth.min())
|
||||
depth = Image.fromarray((depth * 255).astype("uint8"))
|
||||
|
||||
formatted_outputs.append({"predicted_depth": output["predicted_depth"], "depth": depth})
|
||||
|
||||
return formatted_outputs[0] if len(outputs) == 1 else formatted_outputs
|
||||
|
@ -384,3 +384,29 @@ class DPTModelIntegrationTest(unittest.TestCase):
|
||||
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
|
||||
expected_shape = torch.Size((480, 480))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
def test_post_processing_depth_estimation(self):
|
||||
image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
|
||||
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
|
||||
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
predicted_depth = image_processor.post_process_depth_estimation(outputs=outputs)[0]["predicted_depth"]
|
||||
expected_shape = torch.Size((384, 384))
|
||||
self.assertTrue(predicted_depth.shape == expected_shape)
|
||||
|
||||
predicted_depth_l = image_processor.post_process_depth_estimation(outputs=outputs, target_sizes=[(500, 500)])
|
||||
predicted_depth_l = predicted_depth_l[0]["predicted_depth"]
|
||||
expected_shape = torch.Size((500, 500))
|
||||
self.assertTrue(predicted_depth_l.shape == expected_shape)
|
||||
|
||||
output_enlarged = torch.nn.functional.interpolate(
|
||||
predicted_depth.unsqueeze(0).unsqueeze(1), size=(500, 500), mode="bicubic", align_corners=False
|
||||
).squeeze()
|
||||
self.assertTrue(output_enlarged.shape == expected_shape)
|
||||
self.assertTrue(torch.allclose(predicted_depth_l, output_enlarged, rtol=1e-3))
|
||||
|
@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import Dinov2Config, ZoeDepthConfig
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
@ -212,6 +214,25 @@ def prepare_img():
|
||||
@require_vision
|
||||
@slow
|
||||
class ZoeDepthModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice_post_processing = {
|
||||
(False, False): [
|
||||
[[1.1348238, 1.1193453, 1.130562], [1.1754476, 1.1613507, 1.1701596], [1.2287744, 1.2101802, 1.2148322]],
|
||||
[[2.7170, 2.6550, 2.6839], [2.9827, 2.9438, 2.9587], [3.2340, 3.1817, 3.1602]],
|
||||
],
|
||||
(False, True): [
|
||||
[[1.0610938, 1.1042216, 1.1429265], [1.1099341, 1.148696, 1.1817775], [1.1656011, 1.1988826, 1.2268101]],
|
||||
[[2.5848, 2.7391, 2.8694], [2.7882, 2.9872, 3.1244], [2.9436, 3.1812, 3.3188]],
|
||||
],
|
||||
(True, False): [
|
||||
[[1.8382794, 1.8380532, 1.8375976], [1.848761, 1.8485023, 1.8479986], [1.8571457, 1.8568444, 1.8562847]],
|
||||
[[6.2030, 6.1902, 6.1777], [6.2303, 6.2176, 6.2053], [6.2561, 6.2436, 6.2312]],
|
||||
],
|
||||
(True, True): [
|
||||
[[1.8306141, 1.8305621, 1.8303483], [1.8410318, 1.8409299, 1.8406585], [1.8492792, 1.8491366, 1.8488203]],
|
||||
[[6.2616, 6.2520, 6.2435], [6.2845, 6.2751, 6.2667], [6.3065, 6.2972, 6.2887]],
|
||||
],
|
||||
} # (pad, flip)
|
||||
|
||||
def test_inference_depth_estimation(self):
|
||||
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu")
|
||||
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu").to(torch_device)
|
||||
@ -255,3 +276,81 @@ class ZoeDepthModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
def check_target_size(
|
||||
self,
|
||||
image_processor,
|
||||
pad_input,
|
||||
images,
|
||||
outputs,
|
||||
raw_outputs,
|
||||
raw_outputs_flipped=None,
|
||||
):
|
||||
outputs_large = image_processor.post_process_depth_estimation(
|
||||
raw_outputs,
|
||||
[img.size[::-1] for img in images],
|
||||
outputs_flipped=raw_outputs_flipped,
|
||||
target_sizes=[tuple(np.array(img.size[::-1]) * 2) for img in images],
|
||||
do_remove_padding=pad_input,
|
||||
)
|
||||
|
||||
for img, out, out_l in zip(images, outputs, outputs_large):
|
||||
out = out["predicted_depth"]
|
||||
out_l = out_l["predicted_depth"]
|
||||
out_l_reduced = torch.nn.functional.interpolate(
|
||||
out_l.unsqueeze(0).unsqueeze(1), size=img.size[::-1], mode="bicubic", align_corners=False
|
||||
)
|
||||
self.assertTrue((np.array(out_l.shape)[::-1] == np.array(img.size) * 2).all())
|
||||
self.assertTrue(torch.allclose(out, out_l_reduced, rtol=2e-2))
|
||||
|
||||
def check_post_processing_test(self, image_processor, images, model, pad_input=True, flip_aug=True):
|
||||
inputs = image_processor(images=images, return_tensors="pt", do_pad=pad_input).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
raw_outputs = model(**inputs)
|
||||
raw_outputs_flipped = None
|
||||
if flip_aug:
|
||||
raw_outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
|
||||
|
||||
outputs = image_processor.post_process_depth_estimation(
|
||||
raw_outputs,
|
||||
[img.size[::-1] for img in images],
|
||||
outputs_flipped=raw_outputs_flipped,
|
||||
do_remove_padding=pad_input,
|
||||
)
|
||||
|
||||
expected_slices = torch.tensor(self.expected_slice_post_processing[pad_input, flip_aug]).to(torch_device)
|
||||
for img, out, expected_slice in zip(images, outputs, expected_slices):
|
||||
out = out["predicted_depth"]
|
||||
self.assertTrue(img.size == out.shape[::-1])
|
||||
self.assertTrue(torch.allclose(expected_slice, out[:3, :3], rtol=1e-3))
|
||||
|
||||
self.check_target_size(image_processor, pad_input, images, outputs, raw_outputs, raw_outputs_flipped)
|
||||
|
||||
def test_post_processing_depth_estimation_post_processing_nopad_noflip(self):
|
||||
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
|
||||
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
|
||||
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
|
||||
|
||||
self.check_post_processing_test(image_processor, images, model, pad_input=False, flip_aug=False)
|
||||
|
||||
def test_inference_depth_estimation_post_processing_nopad_flip(self):
|
||||
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
|
||||
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
|
||||
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
|
||||
|
||||
self.check_post_processing_test(image_processor, images, model, pad_input=False, flip_aug=True)
|
||||
|
||||
def test_inference_depth_estimation_post_processing_pad_noflip(self):
|
||||
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
|
||||
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
|
||||
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
|
||||
|
||||
self.check_post_processing_test(image_processor, images, model, pad_input=True, flip_aug=False)
|
||||
|
||||
def test_inference_depth_estimation_post_processing_pad_flip(self):
|
||||
images = [prepare_img(), Image.open("./tests/fixtures/tests_samples/COCO/000000004016.png")]
|
||||
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", keep_aspect_ratio=False)
|
||||
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
|
||||
|
||||
self.check_post_processing_test(image_processor, images, model, pad_input=True, flip_aug=True)
|
||||
|
@ -129,7 +129,7 @@ class DepthEstimationPipelineTests(unittest.TestCase):
|
||||
|
||||
# This seems flaky.
|
||||
# self.assertEqual(outputs["depth"], "1a39394e282e9f3b0741a90b9f108977")
|
||||
self.assertEqual(nested_simplify(outputs["predicted_depth"].max().item()), 29.304)
|
||||
self.assertEqual(nested_simplify(outputs["predicted_depth"].max().item()), 29.306)
|
||||
self.assertEqual(nested_simplify(outputs["predicted_depth"].min().item()), 2.662)
|
||||
|
||||
@require_torch
|
||||
|
Loading…
Reference in New Issue
Block a user