Add ZoeDepth (#30136)

* First draft

* Add docs

* Clean up code

* Convert model

* Add image processor

* Convert Zoe_K

* More improvements

* Improve variable names and docstrings

* Improve variable names

* Improve variable names

* Replace nn.sequential

* More improvements

* Convert ZoeD_NK

* Fix most tests

* Verify pixel values

* Verify pixel values

* Add squeeze

* Update beit to support arbitrary window sizes

* Improve image processor

* Improve docstring

* Improve beit

* Improve model outputs

* Add figure

* Fix beit

* Update checkpoint

* Fix repo id

* Add _keys_to_ignore_on_load_unexpected

* More improvements

* Address comments

* Address comments

* Address comments

* Address comments

* Rename variable name

* Add backbone_hidden_size

* Vectorize

* Vectorize more

* Address comments

* Clarify docstring

* Remove backbone_hidden_size

* Fix image processor

* Remove print statements

* Remove print statement

* Add integration test

* Address comments

* Address comments

* Address comments

* Address comments

* Add requires_backends

* Clean up

* Simplify conversion script

* Simplify more

* Simplify more

* Simplify more

* Clean up

* Make sure beit is loaded correctly

* Address comment

* Address bin_configurations

* Use bin_configurations

* Convert models, add integration tests

* Fix doc test

* Address comments

* Unify regressor classes

* Clarify arguments

* Improve resize_image

* Add num_relative_features

* Address comment

* [run-slow]beit,data2vec,zoedepth

* [run-slow]beit,data2vec,zoedepth

* Address comments

* Address comment

* Address comment

* Replace nn.TransformerEncoderLayer and nn.TransformerEncoder

* Replace nn.MultiheadAttention

* Add attributes for patch transformer to config

* Add tests for ensure_multiple_of

* Update organization

* Add tests

* [run-slow] beit data2vec

* Update ruff

* [run-slow] beit data2vec

* Add comment

* Improve docstrings, add test

* Fix interpolate_pos_encoding

* Fix slow tests

* Add docstring

* Update src/transformers/models/zoedepth/image_processing_zoedepth.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/zoedepth/image_processing_zoedepth.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Improve tests and docstrings

* Use run_common_tests

* Improve docstrings

* Improve docstrings

* Improve tests

* Improve tests

* Remove print statements

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
NielsRogge 2024-07-08 11:43:33 +02:00 committed by GitHub
parent 1082361a19
commit 06fd7972ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 3360 additions and 76 deletions

View File

@ -667,6 +667,8 @@
title: ViTMSN
- local: model_doc/yolos
title: YOLOS
- local: model_doc/zoedepth
title: ZoeDepth
title: Vision models
- isExpanded: false
sections:

View File

@ -343,5 +343,6 @@ Flax), PyTorch, and/or TensorFlow.
| [XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2) | ✅ | ✅ | ✅ |
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |
<!-- End table-->

View File

@ -0,0 +1,108 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ZoeDepth
## Overview
The ZoeDepth model was proposed in [ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth](https://arxiv.org/abs/2302.12288) by Shariq Farooq Bhat, Reiner Birkl, Diana Wofk, Peter Wonka, Matthias Müller. ZoeDepth extends the [DPT](dpt) framework for metric (also called absolute) depth estimation. ZoeDepth is pre-trained on 12 datasets using relative depth and fine-tuned on two domains (NYU and KITTI) using metric depth. A lightweight head is used with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier.
The abstract from the paper is the following:
*This paper tackles the problem of depth estimation from a single image. Existing work either focuses on generalization performance disregarding metric scale, i.e. relative depth estimation, or state-of-the-art results on specific datasets, i.e. metric depth estimation. We propose the first approach that combines both worlds, leading to a model with excellent generalization performance while maintaining metric scale. Our flagship model, ZoeD-M12-NK, is pre-trained on 12 datasets using relative depth and fine-tuned on two datasets using metric depth. We use a lightweight head with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier. Our framework admits multiple configurations depending on the datasets used for relative depth pre-training and metric fine-tuning. Without pre-training, we can already significantly improve the state of the art (SOTA) on the NYU Depth v2 indoor dataset. Pre-training on twelve datasets and fine-tuning on the NYU Depth v2 indoor dataset, we can further improve SOTA for a total of 21% in terms of relative absolute error (REL). Finally, ZoeD-M12-NK is the first model that can jointly train on multiple datasets (NYU Depth v2 and KITTI) without a significant drop in performance and achieve unprecedented zero-shot generalization performance to eight unseen datasets from both indoor and outdoor domains.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/zoedepth_architecture_bis.png"
alt="drawing" width="600"/>
<small> ZoeDepth architecture. Taken from the <a href="https://arxiv.org/abs/2302.12288">original paper.</a> </small>
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/isl-org/ZoeDepth).
## Usage tips
- ZoeDepth is an absolute (also called metric) depth estimation model, unlike DPT which is a relative depth estimation model. This means that ZoeDepth is able to estimate depth in metric units like meters.
The easiest to perform inference with ZoeDepth is by leveraging the [pipeline API](../main_classes/pipelines.md):
```python
from transformers import pipeline
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
result = pipe(image)
depth = result["depth"]
```
Alternatively, one can also perform inference using the classes:
```python
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")
# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# interpolate to original size
prediction = torch.nn.functional.interpolate(
predicted_depth.unsqueeze(1),
size=image.size[::-1],
mode="bicubic",
align_corners=False,
)
# visualize the prediction
output = prediction.squeeze().cpu().numpy()
formatted = (output * 255 / np.max(output)).astype("uint8")
depth = Image.fromarray(formatted)
```
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ZoeDepth.
- A demo notebook regarding inference with ZoeDepth models can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/ZoeDepth). 🌎
## ZoeDepthConfig
[[autodoc]] ZoeDepthConfig
## ZoeDepthImageProcessor
[[autodoc]] ZoeDepthImageProcessor
- preprocess
## ZoeDepthForDepthEstimation
[[autodoc]] ZoeDepthForDepthEstimation
- forward

View File

@ -807,6 +807,7 @@ _import_structure = {
"models.xmod": ["XmodConfig"],
"models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"],
"models.zoedepth": ["ZoeDepthConfig"],
"onnx": [],
"pipelines": [
"AudioClassificationPipeline",
@ -1182,6 +1183,7 @@ else:
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
_import_structure["models.vivit"].append("VivitImageProcessor")
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
_import_structure["models.zoedepth"].append("ZoeDepthImageProcessor")
try:
if not is_torchvision_available():
@ -3586,6 +3588,12 @@ else:
"YosoPreTrainedModel",
]
)
_import_structure["models.zoedepth"].extend(
[
"ZoeDepthForDepthEstimation",
"ZoeDepthPreTrainedModel",
]
)
_import_structure["optimization"] = [
"Adafactor",
"AdamW",
@ -5497,6 +5505,7 @@ if TYPE_CHECKING:
from .models.xmod import XmodConfig
from .models.yolos import YolosConfig
from .models.yoso import YosoConfig
from .models.zoedepth import ZoeDepthConfig
# Pipelines
from .pipelines import (
@ -5872,6 +5881,7 @@ if TYPE_CHECKING:
from .models.vitmatte import VitMatteImageProcessor
from .models.vivit import VivitImageProcessor
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
from .models.zoedepth import ZoeDepthImageProcessor
try:
if not is_torchvision_available():
@ -7798,6 +7808,10 @@ if TYPE_CHECKING:
YosoModel,
YosoPreTrainedModel,
)
from .models.zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,
)
# Optimization
from .optimization import (

View File

@ -409,22 +409,22 @@ def validate_preprocess_arguments(
"""
if do_rescale and rescale_factor is None:
raise ValueError("rescale_factor must be specified if do_rescale is True.")
raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
if do_pad and size_divisibility is None:
# Here, size_divisor might be passed as the value of size
raise ValueError(
"Depending on moel, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
"Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
)
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")
raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
if do_center_crop and crop_size is None:
raise ValueError("crop_size must be specified if do_center_crop is True.")
raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
if do_resize and (size is None or resample is None):
raise ValueError("size and resample must be specified if do_resize is True.")
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
# In the future we can add a TF implementation here when we have TF models.

View File

@ -263,4 +263,5 @@ from . import (
xmod,
yolos,
yoso,
zoedepth,
)

View File

@ -291,6 +291,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("xmod", "XmodConfig"),
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
("zoedepth", "ZoeDepthConfig"),
]
)
@ -589,6 +590,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("xmod", "X-MOD"),
("yolos", "YOLOS"),
("yoso", "YOSO"),
("zoedepth", "ZoeDepth"),
]
)

View File

@ -142,6 +142,7 @@ else:
("vitmatte", ("VitMatteImageProcessor",)),
("xclip", ("CLIPImageProcessor",)),
("yolos", ("YolosImageProcessor",)),
("zoedepth", ("ZoeDepthImageProcessor",)),
]
)

View File

@ -792,6 +792,7 @@ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
("depth_anything", "DepthAnythingForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
("zoedepth", "ZoeDepthForDepthEstimation"),
]
)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(

View File

@ -34,7 +34,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -193,12 +193,6 @@ class BeitEmbeddings(nn.Module):
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings, (patch_height, patch_width) = self.patch_embeddings(
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
@ -280,6 +274,7 @@ class BeitPatchEmbeddings(nn.Module):
class BeitSelfAttention(nn.Module):
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
@ -313,6 +308,7 @@ class BeitSelfAttention(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
@ -327,9 +323,11 @@ class BeitSelfAttention(nn.Module):
# Add relative position bias if present.
if self.relative_position_bias is not None:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attention_scores = attention_scores + self.relative_position_bias(
interpolate_pos_encoding, attention_scores.shape[2]
).unsqueeze(0)
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
# Add shared relative position bias if provided.
if relative_position_bias is not None:
@ -407,9 +405,10 @@ class BeitAttention(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_outputs = self.attention(
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
)
attention_output = self.output(self_outputs[0], hidden_states)
@ -475,6 +474,7 @@ class BeitLayer(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
@ -482,6 +482,7 @@ class BeitLayer(nn.Module):
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
@ -520,32 +521,71 @@ class BeitRelativePositionBias(nn.Module):
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
self.relative_position_indices = {}
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
"""
This method creates the relative position index, modified to support arbitrary window sizes,
as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
"""
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
window_area = window_size[0] * window_size[1]
grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
coords = torch.stack(grid) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
)
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
"""
Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
"""
old_height = 2 * self.window_size[0] - 1
old_width = 2 * self.window_size[1] - 1
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH
new_height = 2 * window_size[0] - 1
new_width = 2 * window_size[1] - 1
old_relative_position_bias_table = self.relative_position_bias_table
old_num_relative_distance = self.num_relative_distance
new_num_relative_distance = new_height * new_width + 3
old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = nn.functional.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
new_relative_position_bias_table = torch.cat(
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
)
key = window_size
if key not in self.relative_position_indices.keys():
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
relative_position_bias = relative_position_bias.view(
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
)
# num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
if interpolate_pos_encoding:
relative_position_bias = nn.functional.interpolate(
relative_position_bias.unsqueeze(1),
@ -554,7 +594,7 @@ class BeitRelativePositionBias(nn.Module):
align_corners=False,
).squeeze(1)
return relative_position_bias
return relative_position_bias.unsqueeze(0)
class BeitEncoder(nn.Module):
@ -587,6 +627,7 @@ class BeitEncoder(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@ -606,13 +647,22 @@ class BeitEncoder(nn.Module):
output_attentions,
)
else:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = (
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
hidden_states = layer_outputs[0]
@ -643,6 +693,7 @@ class BeitPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["BeitLayer"]
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
def _init_weights(self, module):
"""Initialize the weights"""
@ -738,7 +789,7 @@ class BeitModel(BeitPreTrainedModel):
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
@ -756,9 +807,6 @@ class BeitModel(BeitPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@ -766,15 +814,17 @@ class BeitModel(BeitPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, (patch_height, patch_width) = self.embeddings(
embedding_output, _ = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
resolution = pixel_values.shape[2:]
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
resolution=resolution,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
@ -1477,9 +1527,14 @@ class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
batch_size = pixel_values.shape[0]
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)
resolution = pixel_values.shape[2:]
outputs = self.encoder(
embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
embedding_output,
output_hidden_states=True,
output_attentions=output_attentions,
resolution=resolution,
return_dict=return_dict,
)
hidden_states = outputs.hidden_states if return_dict else outputs[1]

View File

@ -32,7 +32,7 @@ from ...modeling_outputs import (
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@ -192,12 +192,6 @@ class Data2VecVisionEmbeddings(nn.Module):
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
_, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
)
embeddings, (patch_height, patch_width) = self.patch_embeddings(
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
)
@ -281,6 +275,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
class Data2VecVisionSelfAttention(nn.Module):
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
@ -314,6 +309,7 @@ class Data2VecVisionSelfAttention(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
@ -328,9 +324,11 @@ class Data2VecVisionSelfAttention(nn.Module):
# Add relative position bias if present.
if self.relative_position_bias is not None:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attention_scores = attention_scores + self.relative_position_bias(
interpolate_pos_encoding, attention_scores.shape[2]
).unsqueeze(0)
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
# Add shared relative position bias if provided.
if relative_position_bias is not None:
@ -410,9 +408,10 @@ class Data2VecVisionAttention(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_outputs = self.attention(
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
)
attention_output = self.output(self_outputs[0], hidden_states)
@ -483,6 +482,7 @@ class Data2VecVisionLayer(nn.Module):
output_attentions: bool = False,
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
@ -490,6 +490,7 @@ class Data2VecVisionLayer(nn.Module):
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
@ -529,32 +530,71 @@ class Data2VecVisionRelativePositionBias(nn.Module):
) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
self.relative_position_indices = {}
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
"""
This method creates the relative position index, modified to support arbitrary window sizes,
as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
"""
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
window_area = window_size[0] * window_size[1]
grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
coords = torch.stack(grid) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
)
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
"""
Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
"""
old_height = 2 * self.window_size[0] - 1
old_width = 2 * self.window_size[1] - 1
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
) # Wh*Ww,Wh*Ww,nH
new_height = 2 * window_size[0] - 1
new_width = 2 * window_size[1] - 1
old_relative_position_bias_table = self.relative_position_bias_table
old_num_relative_distance = self.num_relative_distance
new_num_relative_distance = new_height * new_width + 3
old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = nn.functional.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
new_relative_position_bias_table = torch.cat(
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
)
key = window_size
if key not in self.relative_position_indices.keys():
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
relative_position_bias = relative_position_bias.view(
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
)
# num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
if interpolate_pos_encoding:
relative_position_bias = nn.functional.interpolate(
relative_position_bias.unsqueeze(1),
@ -563,7 +603,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
align_corners=False,
).squeeze(1)
return relative_position_bias
return relative_position_bias.unsqueeze(0)
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
@ -597,6 +637,7 @@ class Data2VecVisionEncoder(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
@ -616,13 +657,22 @@ class Data2VecVisionEncoder(nn.Module):
output_attentions,
)
else:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
relative_position_bias = (
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
self.relative_position_bias(
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)
if self.relative_position_bias is not None
else None
)
layer_outputs = layer_module(
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
hidden_states,
layer_head_mask,
output_attentions,
relative_position_bias,
interpolate_pos_encoding,
resolution,
)
hidden_states = layer_outputs[0]
@ -654,6 +704,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Data2VecVisionLayer"]
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
def _init_weights(self, module):
"""Initialize the weights"""
@ -750,7 +801,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
@ -768,9 +819,6 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@ -778,15 +826,17 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, (patch_height, patch_width) = self.embeddings(
embedding_output, _ = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)
resolution = pixel_values.shape[2:]
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
resolution=resolution,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)

View File

@ -58,7 +58,7 @@ def get_resize_output_image_size(
multiple: int,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
x = round(val / multiple) * multiple
if max_val is not None and x > max_val:
@ -87,8 +87,8 @@ def get_resize_output_image_size(
# fit height
scale_width = scale_height
new_height = constraint_to_multiple_of(scale_height * input_height, multiple=multiple)
new_width = constraint_to_multiple_of(scale_width * input_width, multiple=multiple)
new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
return (new_height, new_width)

View File

@ -1021,7 +1021,7 @@ class DPTNeck(nn.Module):
class DPTDepthEstimationHead(nn.Module):
"""
Output head head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
supplementary material).
"""

View File

@ -0,0 +1,67 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available, is_vision_available
from ...utils import OptionalDependencyNotAvailable
_import_structure = {"configuration_zoedepth": ["ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP", "ZoeDepthConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_zoedepth"] = [
"ZoeDepthForDepthEstimation",
"ZoeDepthPreTrainedModel",
]
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_zoedepth"] = ["ZoeDepthImageProcessor"]
if TYPE_CHECKING:
from .configuration_zoedepth import ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP, ZoeDepthConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_zoedepth import ZoeDepthImageProcessor
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ZoeDepth model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"Intel/zoedepth-nyu": "https://huggingface.co/Intel/zoedepth-nyu/resolve/main/config.json",
}
class ZoeDepthConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ZoeDepthForDepthEstimation`]. It is used to instantiate an ZoeDepth
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the ZoeDepth
[Intel/zoedepth-nyu](https://huggingface.co/Intel/zoedepth-nyu) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*, defaults to `BeitConfig()`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the batch normalization layers.
readout_type (`str`, *optional*, defaults to `"project"`):
The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of
the ViT backbone. Can be one of [`"ignore"`, `"add"`, `"project"`].
- "ignore" simply ignores the CLS token.
- "add" passes the information from the CLS token to all other tokens by adding the representations.
- "project" passes information to the other tokens by concatenating the readout to all other tokens before
projecting the
representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
The up/downsampling factors of the reassemble layers.
neck_hidden_sizes (`List[str]`, *optional*, defaults to `[96, 192, 384, 768]`):
The hidden sizes to project to for the feature maps of the backbone.
fusion_hidden_size (`int`, *optional*, defaults to 256):
The number of channels before fusion.
head_in_index (`int`, *optional*, defaults to -1):
The index of the features to use in the heads.
use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
Whether to use bias in the pre-activate residual units of the fusion blocks.
num_relative_features (`int`, *optional*, defaults to 32):
The number of features to use in the relative depth estimation head.
add_projection (`bool`, *optional*, defaults to `False`):
Whether to add a projection layer before the depth estimation head.
bottleneck_features (`int`, *optional*, defaults to 256):
The number of features in the bottleneck layer.
num_attractors (`List[int], *optional*, defaults to `[16, 8, 4, 1]`):
The number of attractors to use in each stage.
bin_embedding_dim (`int`, *optional*, defaults to 128):
The dimension of the bin embeddings.
attractor_alpha (`int`, *optional*, defaults to 1000):
The alpha value to use in the attractor.
attractor_gamma (`int`, *optional*, defaults to 2):
The gamma value to use in the attractor.
attractor_kind (`str`, *optional*, defaults to `"mean"`):
The kind of attractor to use. Can be one of [`"mean"`, `"sum"`].
min_temp (`float`, *optional*, defaults to 0.0212):
The minimum temperature value to consider.
max_temp (`float`, *optional*, defaults to 50.0):
The maximum temperature value to consider.
bin_centers_type (`str`, *optional*, defaults to `"softplus"`):
Activation type used for bin centers. Can be "normed" or "softplus". For "normed" bin centers, linear normalization trick
is applied. This results in bounded bin centers. For "softplus", softplus activation is used and thus are unbounded.
bin_configurations (`List[dict]`, *optional*, defaults to `[{'n_bins': 64, 'min_depth': 0.001, 'max_depth': 10.0}]`):
Configuration for each of the bin heads.
Each configuration should consist of the following keys:
- name (`str`): The name of the bin head - only required in case of multiple bin configurations.
- `n_bins` (`int`): The number of bins to use.
- `min_depth` (`float`): The minimum depth value to consider.
- `max_depth` (`float`): The maximum depth value to consider.
In case only a single configuration is passed, the model will use a single head with the specified configuration.
In case multiple configurations are passed, the model will use multiple heads with the specified configurations.
num_patch_transformer_layers (`int`, *optional*):
The number of transformer layers to use in the patch transformer. Only used in case of multiple bin configurations.
patch_transformer_hidden_size (`int`, *optional*):
The hidden size to use in the patch transformer. Only used in case of multiple bin configurations.
patch_transformer_intermediate_size (`int`, *optional*):
The intermediate size to use in the patch transformer. Only used in case of multiple bin configurations.
patch_transformer_num_attention_heads (`int`, *optional*):
The number of attention heads to use in the patch transformer. Only used in case of multiple bin configurations.
Example:
```python
>>> from transformers import ZoeDepthConfig, ZoeDepthForDepthEstimation
>>> # Initializing a ZoeDepth zoedepth-large style configuration
>>> configuration = ZoeDepthConfig()
>>> # Initializing a model from the zoedepth-large style configuration
>>> model = ZoeDepthForDepthEstimation(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "zoedepth"
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
backbone_kwargs=None,
hidden_act="gelu",
initializer_range=0.02,
batch_norm_eps=1e-05,
readout_type="project",
reassemble_factors=[4, 2, 1, 0.5],
neck_hidden_sizes=[96, 192, 384, 768],
fusion_hidden_size=256,
head_in_index=-1,
use_batch_norm_in_fusion_residual=False,
use_bias_in_fusion_residual=None,
num_relative_features=32,
add_projection=False,
bottleneck_features=256,
num_attractors=[16, 8, 4, 1],
bin_embedding_dim=128,
attractor_alpha=1000,
attractor_gamma=2,
attractor_kind="mean",
min_temp=0.0212,
max_temp=50.0,
bin_centers_type="softplus",
bin_configurations=[{"n_bins": 64, "min_depth": 0.001, "max_depth": 10.0}],
num_patch_transformer_layers=None,
patch_transformer_hidden_size=None,
patch_transformer_intermediate_size=None,
patch_transformer_num_attention_heads=None,
**kwargs,
):
super().__init__(**kwargs)
if readout_type not in ["ignore", "add", "project"]:
raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']")
if attractor_kind not in ["mean", "sum"]:
raise ValueError("Attractor_kind must be one of ['mean', 'sum']")
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `BEiT` backbone.")
backbone_config = CONFIG_MAPPING["beit"](
image_size=384,
num_hidden_layers=24,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
use_relative_position_bias=True,
reshape_hidden_states=False,
out_features=["stage6", "stage12", "stage18", "stage24"],
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.hidden_act = hidden_act
self.use_pretrained_backbone = use_pretrained_backbone
self.initializer_range = initializer_range
self.batch_norm_eps = batch_norm_eps
self.readout_type = readout_type
self.reassemble_factors = reassemble_factors
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
self.head_in_index = head_in_index
self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
self.num_relative_features = num_relative_features
self.add_projection = add_projection
self.bottleneck_features = bottleneck_features
self.num_attractors = num_attractors
self.bin_embedding_dim = bin_embedding_dim
self.attractor_alpha = attractor_alpha
self.attractor_gamma = attractor_gamma
self.attractor_kind = attractor_kind
self.min_temp = min_temp
self.max_temp = max_temp
self.bin_centers_type = bin_centers_type
self.bin_configurations = bin_configurations
self.num_patch_transformer_layers = num_patch_transformer_layers
self.patch_transformer_hidden_size = patch_transformer_hidden_size
self.patch_transformer_intermediate_size = patch_transformer_intermediate_size
self.patch_transformer_num_attention_heads = patch_transformer_num_attention_heads

View File

@ -0,0 +1,426 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ZoeDepth checkpoints from the original repository. URL: https://github.com/isl-org/ZoeDepth.
Original logits where obtained by running the following code:
!git clone -b understanding_zoedepth https://github.com/NielsRogge/ZoeDepth
!python inference.py
"""
import argparse
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import BeitConfig, ZoeDepthConfig, ZoeDepthForDepthEstimation, ZoeDepthImageProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_zoedepth_config(model_name):
image_size = 384
backbone_config = BeitConfig(
image_size=image_size,
num_hidden_layers=24,
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
use_relative_position_bias=True,
reshape_hidden_states=False,
out_features=["stage6", "stage12", "stage18", "stage24"], # beit-large-512 uses [5, 11, 17, 23],
)
neck_hidden_sizes = [256, 512, 1024, 1024]
bin_centers_type = "softplus" if model_name in ["ZoeD_N", "ZoeD_NK"] else "normed"
if model_name == "ZoeD_NK":
bin_configurations = [
{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
{"name": "kitti", "n_bins": 64, "min_depth": 1e-3, "max_depth": 80.0},
]
elif model_name in ["ZoeD_N", "ZoeD_K"]:
bin_configurations = [
{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
]
config = ZoeDepthConfig(
backbone_config=backbone_config,
neck_hidden_sizes=neck_hidden_sizes,
bin_centers_type=bin_centers_type,
bin_configurations=bin_configurations,
num_patch_transformer_layers=4 if model_name == "ZoeD_NK" else None,
patch_transformer_hidden_size=128 if model_name == "ZoeD_NK" else None,
patch_transformer_intermediate_size=1024 if model_name == "ZoeD_NK" else None,
patch_transformer_num_attention_heads=4 if model_name == "ZoeD_NK" else None,
)
return config, image_size
def rename_key(name):
# Transformer backbone
if "core.core.pretrained.model.blocks" in name:
name = name.replace("core.core.pretrained.model.blocks", "backbone.encoder.layer")
if "core.core.pretrained.model.patch_embed.proj" in name:
name = name.replace(
"core.core.pretrained.model.patch_embed.proj", "backbone.embeddings.patch_embeddings.projection"
)
if "core.core.pretrained.model.cls_token" in name:
name = name.replace("core.core.pretrained.model.cls_token", "backbone.embeddings.cls_token")
if "norm1" in name and "patch_transformer" not in name:
name = name.replace("norm1", "layernorm_before")
if "norm2" in name and "patch_transformer" not in name:
name = name.replace("norm2", "layernorm_after")
if "mlp.fc1" in name:
name = name.replace("mlp.fc1", "intermediate.dense")
if "mlp.fc2" in name:
name = name.replace("mlp.fc2", "output.dense")
if "gamma_1" in name:
name = name.replace("gamma_1", "lambda_1")
if "gamma_2" in name:
name = name.replace("gamma_2", "lambda_2")
if "attn.proj" in name:
name = name.replace("attn.proj", "attention.output.dense")
if "attn.relative_position_bias_table" in name:
name = name.replace(
"attn.relative_position_bias_table",
"attention.attention.relative_position_bias.relative_position_bias_table",
)
if "attn.relative_position_index" in name:
name = name.replace(
"attn.relative_position_index", "attention.attention.relative_position_bias.relative_position_index"
)
# activation postprocessing (readout projections + resize blocks)
if "core.core.pretrained.act_postprocess1.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess1.0.project", "neck.reassemble_stage.readout_projects.0"
)
if "core.core.pretrained.act_postprocess2.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess2.0.project", "neck.reassemble_stage.readout_projects.1"
)
if "core.core.pretrained.act_postprocess3.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess3.0.project", "neck.reassemble_stage.readout_projects.2"
)
if "core.core.pretrained.act_postprocess4.0.project" in name:
name = name.replace(
"core.core.pretrained.act_postprocess4.0.project", "neck.reassemble_stage.readout_projects.3"
)
if "core.core.pretrained.act_postprocess1.3" in name:
name = name.replace("core.core.pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection")
if "core.core.pretrained.act_postprocess2.3" in name:
name = name.replace("core.core.pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection")
if "core.core.pretrained.act_postprocess3.3" in name:
name = name.replace("core.core.pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection")
if "core.core.pretrained.act_postprocess4.3" in name:
name = name.replace("core.core.pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection")
if "core.core.pretrained.act_postprocess1.4" in name:
name = name.replace("core.core.pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize")
if "core.core.pretrained.act_postprocess2.4" in name:
name = name.replace("core.core.pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize")
if "core.core.pretrained.act_postprocess4.4" in name:
name = name.replace("core.core.pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize")
# scratch convolutions
if "core.core.scratch.layer1_rn.weight" in name:
name = name.replace("core.core.scratch.layer1_rn.weight", "neck.convs.0.weight")
if "core.core.scratch.layer2_rn.weight" in name:
name = name.replace("core.core.scratch.layer2_rn.weight", "neck.convs.1.weight")
if "core.core.scratch.layer3_rn.weight" in name:
name = name.replace("core.core.scratch.layer3_rn.weight", "neck.convs.2.weight")
if "core.core.scratch.layer4_rn.weight" in name:
name = name.replace("core.core.scratch.layer4_rn.weight", "neck.convs.3.weight")
# fusion layers
# tricky here: mapping = {1:3, 2:2, 3:1, 4:0}
if "core.core.scratch.refinenet1" in name:
name = name.replace("core.core.scratch.refinenet1", "neck.fusion_stage.layers.3")
if "core.core.scratch.refinenet2" in name:
name = name.replace("core.core.scratch.refinenet2", "neck.fusion_stage.layers.2")
if "core.core.scratch.refinenet3" in name:
name = name.replace("core.core.scratch.refinenet3", "neck.fusion_stage.layers.1")
if "core.core.scratch.refinenet4" in name:
name = name.replace("core.core.scratch.refinenet4", "neck.fusion_stage.layers.0")
if "resConfUnit1" in name:
name = name.replace("resConfUnit1", "residual_layer1")
if "resConfUnit2" in name:
name = name.replace("resConfUnit2", "residual_layer2")
if "conv1" in name:
name = name.replace("conv1", "convolution1")
if "conv2" in name and "residual_layer" in name:
name = name.replace("conv2", "convolution2")
if "out_conv" in name:
name = name.replace("out_conv", "projection")
# relative depth estimation head
if "core.core.scratch.output_conv.0" in name:
name = name.replace("core.core.scratch.output_conv.0", "relative_head.conv1")
if "core.core.scratch.output_conv.2" in name:
name = name.replace("core.core.scratch.output_conv.2", "relative_head.conv2")
if "core.core.scratch.output_conv.4" in name:
name = name.replace("core.core.scratch.output_conv.4", "relative_head.conv3")
# patch transformer
if "patch_transformer" in name:
name = name.replace("patch_transformer", "metric_head.patch_transformer")
if "mlp_classifier.0" in name:
name = name.replace("mlp_classifier.0", "metric_head.mlp_classifier.linear1")
if "mlp_classifier.2" in name:
name = name.replace("mlp_classifier.2", "metric_head.mlp_classifier.linear2")
if "projectors" in name:
name = name.replace("projectors", "metric_head.projectors")
if "seed_bin_regressors" in name:
name = name.replace("seed_bin_regressors", "metric_head.seed_bin_regressors")
if "seed_bin_regressor" in name and "seed_bin_regressors" not in name:
name = name.replace("seed_bin_regressor", "metric_head.seed_bin_regressor")
if "seed_projector" in name:
name = name.replace("seed_projector", "metric_head.seed_projector")
if "_net.0" in name:
name = name.replace("_net.0", "conv1")
if "_net.2" in name:
name = name.replace("_net.2", "conv2")
if "attractors" in name:
name = name.replace("attractors", "metric_head.attractors")
if "conditional_log_binomial" in name:
name = name.replace("conditional_log_binomial", "metric_head.conditional_log_binomial")
# metric depth estimation head
if "conv2" in name and "metric_head" not in name and "attractors" not in name and "relative_head" not in name:
name = name.replace("conv2", "metric_head.conv2")
if "transformer_encoder.layers" in name:
name = name.replace("transformer_encoder.layers", "transformer_encoder")
return name
def read_in_q_k_v_metric_head(state_dict):
hidden_size = 128
for i in range(4):
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"patch_transformer.transformer_encoder.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"patch_transformer.transformer_encoder.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.query.weight"] = in_proj_weight[
:hidden_size, :
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.query.bias"] = in_proj_bias[:hidden_size]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.key.weight"] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.key.bias"] = in_proj_bias[
hidden_size : hidden_size * 2
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.value.weight"] = in_proj_weight[
-hidden_size:, :
]
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.value.bias"] = in_proj_bias[-hidden_size:]
def convert_state_dict(orig_state_dict):
for key in orig_state_dict.copy().keys():
val = orig_state_dict.pop(key)
# rename key
new_name = rename_key(key)
orig_state_dict[new_name] = val
return orig_state_dict
def remove_ignore_keys(state_dict):
for key, _ in state_dict.copy().items():
if (
"fc_norm" in key
or "relative_position_index" in key
or "k_idx" in key
or "K_minus_1" in key
or "core.core.pretrained.model.head" in key
):
state_dict.pop(key, None)
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
hidden_size = config.backbone_config.hidden_size
for i in range(config.backbone_config.num_hidden_layers):
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.v_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
hidden_size : hidden_size * 2, :
]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :]
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# We will verify our results on an image
def prepare_img():
filepath = hf_hub_download(repo_id="shariqfarooq/ZoeDepth", filename="examples/person_1.jpeg", repo_type="space")
image = Image.open(filepath).convert("RGB")
return image
@torch.no_grad()
def convert_zoedepth_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
"""
Copy/paste/tweak model's weights to our ZoeDepth structure.
"""
# define ZoeDepth configuration based on URL
config, _ = get_zoedepth_config(model_name)
# load original model
original_model = torch.hub.load(
"NielsRogge/ZoeDepth:understanding_zoedepth", model_name, pretrained=True, force_reload=True
)
original_model.eval()
state_dict = original_model.state_dict()
print("Original state dict:")
for name, param in state_dict.items():
print(name, param.shape)
# read in qkv matrices
read_in_q_k_v(state_dict, config)
if model_name == "ZoeD_NK":
read_in_q_k_v_metric_head(state_dict)
# rename keys
state_dict = convert_state_dict(state_dict)
# remove certain keys
remove_ignore_keys(state_dict)
# load HuggingFace model
model = ZoeDepthForDepthEstimation(config)
model.load_state_dict(state_dict)
model.eval()
# verify image processor
image = prepare_img()
image_processor = ZoeDepthImageProcessor()
pixel_values = image_processor(image, return_tensors="pt").pixel_values
filepath = hf_hub_download(
repo_id="nielsr/test-image",
filename="zoedepth_pixel_values.pt",
repo_type="dataset",
)
original_pixel_values = torch.load(filepath, map_location="cpu")
assert torch.allclose(pixel_values, original_pixel_values)
# verify logits
# this was done on a resized version of the cats image (384x384)
filepath = hf_hub_download(
repo_id="nielsr/test-image",
filename="zoedepth_pixel_values.pt",
repo_type="dataset",
revision="1865dbb81984f01c89e83eec10f8d07efd10743d",
)
cats_pixel_values = torch.load(filepath, map_location="cpu")
depth = model(cats_pixel_values).predicted_depth
# Verify logits
# These were obtained by inserting the pixel_values at the patch embeddings of BEiT
if model_name == "ZoeD_N":
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor([[1.0328, 1.0604, 1.0747], [1.0816, 1.1293, 1.1456], [1.1117, 1.1629, 1.1766]])
elif model_name == "ZoeD_K":
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor([[1.6567, 1.6852, 1.7065], [1.6707, 1.6764, 1.6713], [1.7195, 1.7166, 1.7118]])
elif model_name == "ZoeD_NK":
expected_shape = torch.Size([1, 384, 384])
expected_slice = torch.tensor([[1.1228, 1.1079, 1.1382], [1.1807, 1.1658, 1.1891], [1.2344, 1.2094, 1.2317]])
print("Shape of depth:", depth.shape)
print("First 3x3 slice of depth:", depth[0, :3, :3])
assert depth.shape == torch.Size(expected_shape)
assert torch.allclose(depth[0, :3, :3], expected_slice, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
model_name_to_repo_id = {
"ZoeD_N": "zoedepth-nyu",
"ZoeD_K": "zoedepth-kitti",
"ZoeD_NK": "zoedepth-nyu-kitti",
}
print("Pushing model and processor to the hub...")
repo_id = model_name_to_repo_id[model_name]
model.push_to_hub(f"Intel/{repo_id}")
image_processor = ZoeDepthImageProcessor()
image_processor.push_to_hub(f"Intel/{repo_id}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="ZoeD_N",
choices=["ZoeD_N", "ZoeD_K", "ZoeD_NK"],
type=str,
help="Name of the original ZoeDepth checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=False,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
)
args = parser.parse_args()
convert_zoedepth_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)

View File

@ -0,0 +1,454 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ZoeDepth."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
if is_vision_available():
import PIL
if is_torch_available():
import torch
from torch import nn
logger = logging.get_logger(__name__)
def get_resize_output_image_size(
input_image: np.ndarray,
output_size: Union[int, Iterable[int]],
keep_aspect_ratio: bool,
multiple: int,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
def constrain_to_multiple_of(val, multiple, min_val=0):
x = (np.round(val / multiple) * multiple).astype(int)
if x < min_val:
x = math.ceil(val / multiple) * multiple
return x
output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
input_height, input_width = get_image_size(input_image, input_data_format)
output_height, output_width = output_size
# determine new height and width
scale_height = output_height / input_height
scale_width = output_width / input_width
if keep_aspect_ratio:
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
return (new_height, new_width)
class ZoeDepthImageProcessor(BaseImageProcessor):
r"""
Constructs a ZoeDepth image processor.
Args:
do_pad (`bool`, *optional*, defaults to `True`):
Whether to apply pad the input.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in
`preprocess`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overidden by `rescale_factor` in `preprocess`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`.
size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 512}`):
Size of the image after resizing. Size of the image after resizing. If `keep_aspect_ratio` is `True`,
the image is resized by choosing the smaller of the height and width scaling factors and using it for both dimensions.
If `ensure_multiple_of` is also set, the image is further resized to a size that is a multiple of this value.
Can be overidden by `size` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`.
keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
multiple of this value by flooring the height and width to the nearest multiple of this value.
Can be overidden by `keep_aspect_ratio` in `preprocess`.
ensure_multiple_of (`int`, *optional*, defaults to 32):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
the height and width to the nearest multiple of this value.
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overidden by `ensure_multiple_of`
in `preprocess`.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_pad: bool = True,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
keep_aspect_ratio: bool = True,
ensure_multiple_of: int = 32,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_pad = do_pad
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
size = size if size is not None else {"height": 384, "width": 512}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.keep_aspect_ratio = keep_aspect_ratio
self.ensure_multiple_of = ensure_multiple_of
self.resample = resample
self._valid_processor_keys = [
"images",
"do_resize",
"size",
"keep_aspect_ratio",
"ensure_multiple_of",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"return_tensors",
"data_format",
"input_data_format",
]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
set, the image is resized to a size that is a multiple of this value.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Target size of the output image.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
ensure_multiple_of (`int`, *optional*, defaults to 1):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
specified in `size`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
data_format = data_format if data_format is not None else input_data_format
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
output_size = get_resize_output_image_size(
image,
output_size=(size["height"], size["width"]),
keep_aspect_ratio=keep_aspect_ratio,
multiple=ensure_multiple_of,
input_data_format=input_data_format,
)
height, width = output_size
torch_image = torch.from_numpy(image).unsqueeze(0)
torch_image = torch_image.permute(0, 3, 1, 2) if input_data_format == "channels_last" else torch_image
# TODO support align_corners=True in image_transforms.resize
requires_backends(self, "torch")
resample_to_mode = {PILImageResampling.BILINEAR: "bilinear", PILImageResampling.BICUBIC: "bicubic"}
mode = resample_to_mode[resample]
resized_image = nn.functional.interpolate(
torch_image, (int(height), int(width)), mode=mode, align_corners=True
)
resized_image = resized_image.squeeze().numpy()
resized_image = to_channel_dimension_format(
resized_image, data_format, input_channel_dim=ChannelDimension.FIRST
)
return resized_image
def pad_image(
self,
image: np.array,
mode: PaddingMode = PaddingMode.REFLECT,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Pad an image as done in the original ZoeDepth implementation.
Padding fixes the boundary artifacts in the output depth map.
Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset
which has a black or white border around the image. This function pads the input image and crops
the prediction back to the original size / view.
Args:
image (`np.ndarray`):
Image to pad.
mode (`PaddingMode`):
The padding mode to use. Can be one of:
- `"constant"`: pads with a constant value.
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
vector along each axis.
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
height, width = get_image_size(image, input_data_format)
pad_height = int(np.sqrt(height / 2) * 3)
pad_width = int(np.sqrt(width / 2) * 3)
return pad(
image,
padding=((pad_height, pad_height), (pad_width, pad_width)),
mode=mode,
data_format=data_format,
input_data_format=input_data_format,
)
def preprocess(
self,
images: ImageInput,
do_pad: bool = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_resize: bool = None,
size: int = None,
keep_aspect_ratio: bool = None,
ensure_multiple_of: int = None,
resample: PILImageResampling = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the input image.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. If `keep_aspect_ratio` is `True`, he image is resized by choosing the smaller of
the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of` is also set,
the image is further resized to a size that is a multiple of this value.
keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
multiple of this value by flooring the height and width to the nearest multiple of this value.
ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by flooring
the height and width to the nearest multiple of this value.
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overidden by `ensure_multiple_of` in `preprocess`.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
has an effect if `do_resize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_pad:
images = [self.pad_image(image=image, input_data_format=input_data_format) for image in images]
if do_resize:
images = [
self.resize(
image=image,
size=size,
resample=resample,
keep_aspect_ratio=keep_aspect_ratio,
ensure_multiple_of=ensure_multiple_of,
input_data_format=input_data_format,
)
for image in images
]
if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

File diff suppressed because it is too large Load Diff

View File

@ -9660,6 +9660,20 @@ class YosoPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ZoeDepthForDepthEstimation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ZoeDepthPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Adafactor(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -651,3 +651,10 @@ class YolosImageProcessor(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class ZoeDepthImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])

View File

View File

@ -0,0 +1,187 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.file_utils import is_vision_available
from transformers.testing_utils import require_torch, require_vision
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_vision_available():
from transformers import ZoeDepthImageProcessor
class ZoeDepthImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
ensure_multiple_of=32,
keep_aspect_ratio=False,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_pad=False,
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.ensure_multiple_of = ensure_multiple_of
self.keep_aspect_ratio = keep_aspect_ratio
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_pad = do_pad
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"ensure_multiple_of": self.ensure_multiple_of,
"keep_aspect_ratio": self.keep_aspect_ratio,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_pad": self.do_pad,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.ensure_multiple_of, self.ensure_multiple_of
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ZoeDepthImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = ZoeDepthImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "ensure_multiple_of"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_pad"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
def test_ensure_multiple_of(self):
# Test variable by turning off all other variables which affect the size, size which is not multiple of 32
image = np.zeros((489, 640, 3))
size = {"height": 380, "width": 513}
multiple = 32
image_processor = ZoeDepthImageProcessor(
do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
self.assertEqual(list(pixel_values.shape), [1, 3, 384, 512])
self.assertTrue(pixel_values.shape[2] % multiple == 0)
self.assertTrue(pixel_values.shape[3] % multiple == 0)
# Test variable by turning off all other variables which affect the size, size which is already multiple of 32
image = np.zeros((511, 511, 3))
height, width = 512, 512
size = {"height": height, "width": width}
multiple = 32
image_processor = ZoeDepthImageProcessor(
do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
self.assertEqual(list(pixel_values.shape), [1, 3, height, width])
self.assertTrue(pixel_values.shape[2] % multiple == 0)
self.assertTrue(pixel_values.shape[3] % multiple == 0)
def test_keep_aspect_ratio(self):
# Test `keep_aspect_ratio=True` by turning off all other variables which affect the size
height, width = 489, 640
image = np.zeros((height, width, 3))
size = {"height": 512, "width": 512}
image_processor = ZoeDepthImageProcessor(do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
# As can be seen, the image is resized to the maximum size that fits in the specified size
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670])
# Test `keep_aspect_ratio=False` by turning off all other variables which affect the size
image_processor = ZoeDepthImageProcessor(
do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1
)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
# As can be seen, the size is respected
self.assertEqual(list(pixel_values.shape), [1, 3, size["height"], size["width"]])
# Test `keep_aspect_ratio=True` with `ensure_multiple_of` set
image = np.zeros((489, 640, 3))
size = {"height": 511, "width": 511}
multiple = 32
image_processor = ZoeDepthImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple)
pixel_values = image_processor(image, return_tensors="pt").pixel_values
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
self.assertTrue(pixel_values.shape[2] % multiple == 0)
self.assertTrue(pixel_values.shape[3] % multiple == 0)

View File

@ -0,0 +1,257 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ZoeDepth model."""
import unittest
from transformers import Dinov2Config, ZoeDepthConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import ZoeDepthForDepthEstimation
if is_vision_available():
from PIL import Image
from transformers import ZoeDepthImageProcessor
class ZoeDepthModelTester:
def __init__(
self,
parent,
batch_size=2,
num_channels=3,
image_size=32,
patch_size=16,
use_labels=True,
num_labels=3,
is_training=True,
hidden_size=4,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=8,
out_features=["stage1", "stage2"],
apply_layernorm=False,
reshape_hidden_states=False,
neck_hidden_sizes=[2, 2],
fusion_hidden_size=6,
bottleneck_features=6,
num_out_features=[6, 6, 6, 6],
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.patch_size = patch_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.out_features = out_features
self.apply_layernorm = apply_layernorm
self.reshape_hidden_states = reshape_hidden_states
self.use_labels = use_labels
self.num_labels = num_labels
self.is_training = is_training
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
self.bottleneck_features = bottleneck_features
self.num_out_features = num_out_features
# ZoeDepth's sequence length
self.seq_length = (self.image_size // self.patch_size) ** 2 + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ZoeDepthConfig(
backbone_config=self.get_backbone_config(),
backbone=None,
neck_hidden_sizes=self.neck_hidden_sizes,
fusion_hidden_size=self.fusion_hidden_size,
bottleneck_features=self.bottleneck_features,
num_out_features=self.num_out_features,
)
def get_backbone_config(self):
return Dinov2Config(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
is_training=self.is_training,
out_features=self.out_features,
reshape_hidden_states=self.reshape_hidden_states,
)
def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = ZoeDepthForDepthEstimation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class ZoeDepthModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ZoeDepth does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (ZoeDepthForDepthEstimation,) if is_torch_available() else ()
pipeline_model_mapping = {"depth-estimation": ZoeDepthForDepthEstimation} if is_torch_available() else {}
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = ZoeDepthModelTester(self)
self.config_tester = ConfigTester(
self, config_class=ZoeDepthConfig, has_text_modality=False, hidden_size=37, common_properties=[]
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model and hence no input_embeddings")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model and hence no input_embeddings")
def test_model_get_set_embeddings(self):
pass
def test_for_depth_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
@unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model and hence no input_embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model")
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(reason="ZoeDepth does not support training yet")
def test_training(self):
pass
@unittest.skip(reason="ZoeDepth does not support training yet")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="ZoeDepth does not support training yet")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="ZoeDepth does not support training yet")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@slow
def test_model_from_pretrained(self):
model_name = "Intel/zoedepth-nyu"
model = ZoeDepthForDepthEstimation.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
@slow
class ZoeDepthModelIntegrationTest(unittest.TestCase):
def test_inference_depth_estimation(self):
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu").to(torch_device)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# verify the predicted depth
expected_shape = torch.Size((1, 384, 512))
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[1.0020, 1.0219, 1.0389], [1.0349, 1.0816, 1.1000], [1.0576, 1.1094, 1.1249]],
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
def test_inference_depth_estimation_multiple_heads(self):
image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# verify the predicted depth
expected_shape = torch.Size((1, 384, 512))
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[1.1571, 1.1438, 1.1783], [1.2163, 1.2036, 1.2320], [1.2688, 1.2461, 1.2734]],
).to(torch_device)
self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))