mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add ZoeDepth (#30136)
* First draft
* Add docs
* Clean up code
* Convert model
* Add image processor
* Convert Zoe_K
* More improvements
* Improve variable names and docstrings
* Improve variable names
* Improve variable names
* Replace nn.sequential
* More improvements
* Convert ZoeD_NK
* Fix most tests
* Verify pixel values
* Verify pixel values
* Add squeeze
* Update beit to support arbitrary window sizes
* Improve image processor
* Improve docstring
* Improve beit
* Improve model outputs
* Add figure
* Fix beit
* Update checkpoint
* Fix repo id
* Add _keys_to_ignore_on_load_unexpected
* More improvements
* Address comments
* Address comments
* Address comments
* Address comments
* Rename variable name
* Add backbone_hidden_size
* Vectorize
* Vectorize more
* Address comments
* Clarify docstring
* Remove backbone_hidden_size
* Fix image processor
* Remove print statements
* Remove print statement
* Add integration test
* Address comments
* Address comments
* Address comments
* Address comments
* Add requires_backends
* Clean up
* Simplify conversion script
* Simplify more
* Simplify more
* Simplify more
* Clean up
* Make sure beit is loaded correctly
* Address comment
* Address bin_configurations
* Use bin_configurations
* Convert models, add integration tests
* Fix doc test
* Address comments
* Unify regressor classes
* Clarify arguments
* Improve resize_image
* Add num_relative_features
* Address comment
* [run-slow]beit,data2vec,zoedepth
* [run-slow]beit,data2vec,zoedepth
* Address comments
* Address comment
* Address comment
* Replace nn.TransformerEncoderLayer and nn.TransformerEncoder
* Replace nn.MultiheadAttention
* Add attributes for patch transformer to config
* Add tests for ensure_multiple_of
* Update organization
* Add tests
* [run-slow] beit data2vec
* Update ruff
* [run-slow] beit data2vec
* Add comment
* Improve docstrings, add test
* Fix interpolate_pos_encoding
* Fix slow tests
* Add docstring
* Update src/transformers/models/zoedepth/image_processing_zoedepth.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/zoedepth/image_processing_zoedepth.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Improve tests and docstrings
* Use run_common_tests
* Improve docstrings
* Improve docstrings
* Improve tests
* Improve tests
* Remove print statements

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
1082361a19
commit
06fd7972ac
@ -667,6 +667,8 @@
  title: ViTMSN
- local: model_doc/yolos
  title: YOLOS
- local: model_doc/zoedepth
  title: ZoeDepth
title: Vision models
- isExpanded: false
  sections:
@ -343,5 +343,6 @@ Flax), PyTorch, and/or TensorFlow.
| [XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2) | ✅ | ✅ | ✅ |
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |

<!-- End table-->
|
108
docs/source/en/model_doc/zoedepth.md
Normal file
@ -0,0 +1,108 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# ZoeDepth

## Overview

The ZoeDepth model was proposed in [ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth](https://arxiv.org/abs/2302.12288) by Shariq Farooq Bhat, Reiner Birkl, Diana Wofk, Peter Wonka, Matthias Müller. ZoeDepth extends the [DPT](dpt) framework for metric (also called absolute) depth estimation. ZoeDepth is pre-trained on 12 datasets using relative depth and fine-tuned on two domains (NYU and KITTI) using metric depth. A lightweight head is used with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier.

The abstract from the paper is the following:

*This paper tackles the problem of depth estimation from a single image. Existing work either focuses on generalization performance disregarding metric scale, i.e. relative depth estimation, or state-of-the-art results on specific datasets, i.e. metric depth estimation. We propose the first approach that combines both worlds, leading to a model with excellent generalization performance while maintaining metric scale. Our flagship model, ZoeD-M12-NK, is pre-trained on 12 datasets using relative depth and fine-tuned on two datasets using metric depth. We use a lightweight head with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier. Our framework admits multiple configurations depending on the datasets used for relative depth pre-training and metric fine-tuning. Without pre-training, we can already significantly improve the state of the art (SOTA) on the NYU Depth v2 indoor dataset. Pre-training on twelve datasets and fine-tuning on the NYU Depth v2 indoor dataset, we can further improve SOTA for a total of 21% in terms of relative absolute error (REL). Finally, ZoeD-M12-NK is the first model that can jointly train on multiple datasets (NYU Depth v2 and KITTI) without a significant drop in performance and achieve unprecedented zero-shot generalization performance to eight unseen datasets from both indoor and outdoor domains.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/zoedepth_architecture_bis.png"
alt="drawing" width="600"/>

<small> ZoeDepth architecture. Taken from the <a href="https://arxiv.org/abs/2302.12288">original paper.</a> </small>

This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/isl-org/ZoeDepth).

## Usage tips

- ZoeDepth is an absolute (also called metric) depth estimation model, unlike DPT which is a relative depth estimation model. This means that ZoeDepth is able to estimate depth in metric units like meters.

The easiest way to perform inference with ZoeDepth is by leveraging the [pipeline API](../main_classes/pipelines.md):

```python
from transformers import pipeline
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti")
result = pipe(image)
depth = result["depth"]
```
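
The pipeline also returns the raw model prediction alongside the rendered depth image. As a small, hedged follow-up (key names as returned by the current depth-estimation pipeline; inspect `result.keys()` if your version differs):

```python
predicted_depth = result["predicted_depth"]  # raw torch.Tensor at the model's working resolution
depth.save("cats_depth.png")  # `depth` is a PIL.Image.Image resized to the original input
```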

Alternatively, one can also perform inference using the classes:

```python
from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")

# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth

# interpolate to original size
prediction = torch.nn.functional.interpolate(
    predicted_depth.unsqueeze(1),
    size=image.size[::-1],
    mode="bicubic",
    align_corners=False,
)

# visualize the prediction
output = prediction.squeeze().cpu().numpy()
formatted = (output * 255 / np.max(output)).astype("uint8")
depth = Image.fromarray(formatted)
```
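
Since ZoeDepth is a metric depth model, the interpolated prediction can be read directly as distances, roughly in meters for the NYU/KITTI heads. A minimal sketch continuing the snippet above:

```python
# read the estimated distance of the center pixel
depth_map = prediction.squeeze()
h, w = depth_map.shape
print(f"Depth at image center: {depth_map[h // 2, w // 2].item():.2f} m")
```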

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ZoeDepth.

- A demo notebook regarding inference with ZoeDepth models can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/ZoeDepth). 🌎

## ZoeDepthConfig

[[autodoc]] ZoeDepthConfig

## ZoeDepthImageProcessor

[[autodoc]] ZoeDepthImageProcessor
    - preprocess

## ZoeDepthForDepthEstimation

[[autodoc]] ZoeDepthForDepthEstimation
    - forward
@ -807,6 +807,7 @@ _import_structure = {
|
||||
"models.xmod": ["XmodConfig"],
|
||||
"models.yolos": ["YolosConfig"],
|
||||
"models.yoso": ["YosoConfig"],
|
||||
"models.zoedepth": ["ZoeDepthConfig"],
|
||||
"onnx": [],
|
||||
"pipelines": [
|
||||
"AudioClassificationPipeline",
|
||||
@ -1182,6 +1183,7 @@ else:
|
||||
_import_structure["models.vitmatte"].append("VitMatteImageProcessor")
|
||||
_import_structure["models.vivit"].append("VivitImageProcessor")
|
||||
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
|
||||
_import_structure["models.zoedepth"].append("ZoeDepthImageProcessor")
|
||||
|
||||
try:
|
||||
if not is_torchvision_available():
|
||||
@ -3586,6 +3588,12 @@ else:
|
||||
"YosoPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.zoedepth"].extend(
|
||||
[
|
||||
"ZoeDepthForDepthEstimation",
|
||||
"ZoeDepthPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["optimization"] = [
|
||||
"Adafactor",
|
||||
"AdamW",
|
||||
@ -5497,6 +5505,7 @@ if TYPE_CHECKING:
|
||||
from .models.xmod import XmodConfig
|
||||
from .models.yolos import YolosConfig
|
||||
from .models.yoso import YosoConfig
|
||||
from .models.zoedepth import ZoeDepthConfig
|
||||
|
||||
# Pipelines
|
||||
from .pipelines import (
|
||||
@ -5872,6 +5881,7 @@ if TYPE_CHECKING:
|
||||
from .models.vitmatte import VitMatteImageProcessor
|
||||
from .models.vivit import VivitImageProcessor
|
||||
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
|
||||
from .models.zoedepth import ZoeDepthImageProcessor
|
||||
|
||||
try:
|
||||
if not is_torchvision_available():
|
||||
@ -7798,6 +7808,10 @@ if TYPE_CHECKING:
|
||||
YosoModel,
|
||||
YosoPreTrainedModel,
|
||||
)
|
||||
from .models.zoedepth import (
|
||||
ZoeDepthForDepthEstimation,
|
||||
ZoeDepthPreTrainedModel,
|
||||
)
|
||||
|
||||
# Optimization
|
||||
from .optimization import (
|
||||
|
@ -409,22 +409,22 @@ def validate_preprocess_arguments(

    """
    if do_rescale and rescale_factor is None:
        raise ValueError("rescale_factor must be specified if do_rescale is True.")
        raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
            "Depending on moel, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
            "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
        )

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")
        raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")

    if do_center_crop and crop_size is None:
        raise ValueError("crop_size must be specified if do_center_crop is True.")
        raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")

    if do_resize and (size is None or resample is None):
        raise ValueError("size and resample must be specified if do_resize is True.")
        raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")


# In the future we can add a TF implementation here when we have TF models.
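
For context, a minimal sketch of how these checks surface to callers (assuming the remaining keyword arguments default to `None`, as the signature above suggests):

```python
from transformers.image_utils import validate_preprocess_arguments

# An enabled flag without its companion value fails fast, before any image is touched:
validate_preprocess_arguments(do_rescale=True, rescale_factor=None)
# ValueError: `rescale_factor` must be specified if `do_rescale` is `True`.
```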
|
||||
|
@ -263,4 +263,5 @@ from . import (
|
||||
xmod,
|
||||
yolos,
|
||||
yoso,
|
||||
zoedepth,
|
||||
)
|
||||
|
@ -291,6 +291,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("xmod", "XmodConfig"),
|
||||
("yolos", "YolosConfig"),
|
||||
("yoso", "YosoConfig"),
|
||||
("zoedepth", "ZoeDepthConfig"),
|
||||
]
|
||||
)
|
||||
|
||||
@ -589,6 +590,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("xmod", "X-MOD"),
|
||||
("yolos", "YOLOS"),
|
||||
("yoso", "YOSO"),
|
||||
("zoedepth", "ZoeDepth"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -142,6 +142,7 @@ else:
|
||||
("vitmatte", ("VitMatteImageProcessor",)),
|
||||
("xclip", ("CLIPImageProcessor",)),
|
||||
("yolos", ("YolosImageProcessor",)),
|
||||
("zoedepth", ("ZoeDepthImageProcessor",)),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -792,6 +792,7 @@ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
|
||||
("depth_anything", "DepthAnythingForDepthEstimation"),
|
||||
("dpt", "DPTForDepthEstimation"),
|
||||
("glpn", "GLPNForDepthEstimation"),
|
||||
("zoedepth", "ZoeDepthForDepthEstimation"),
|
||||
]
|
||||
)
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
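
With this mapping entry registered, the auto classes can resolve ZoeDepth checkpoints without importing the model class explicitly. A small, hedged usage sketch (checkpoint name taken from the documentation above):

```python
from transformers import AutoImageProcessor, AutoModelForDepthEstimation

processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")       # resolves to ZoeDepthImageProcessor
model = AutoModelForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")  # resolves to ZoeDepthForDepthEstimation
```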
|
||||
|
@ -34,7 +34,7 @@ from ...modeling_outputs import (
|
||||
SemanticSegmenterOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -193,12 +193,6 @@ class BeitEmbeddings(nn.Module):
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||
)
|
||||
@ -280,6 +274,7 @@ class BeitPatchEmbeddings(nn.Module):
|
||||
class BeitSelfAttention(nn.Module):
|
||||
def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
||||
@ -313,6 +308,7 @@ class BeitSelfAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
@ -327,9 +323,11 @@ class BeitSelfAttention(nn.Module):
|
||||
|
||||
# Add relative position bias if present.
|
||||
if self.relative_position_bias is not None:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
attention_scores = attention_scores + self.relative_position_bias(
|
||||
interpolate_pos_encoding, attention_scores.shape[2]
|
||||
).unsqueeze(0)
|
||||
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
||||
)
|
||||
|
||||
# Add shared relative position bias if provided.
|
||||
if relative_position_bias is not None:
|
||||
@ -407,9 +405,10 @@ class BeitAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_outputs = self.attention(
|
||||
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
|
||||
)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
@ -475,6 +474,7 @@ class BeitLayer(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
||||
@ -482,6 +482,7 @@ class BeitLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
relative_position_bias=relative_position_bias,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
resolution=resolution,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
@ -520,32 +521,71 @@ class BeitRelativePositionBias(nn.Module):
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
self.relative_position_indices = {}
|
||||
|
||||
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
This method creates the relative position index, modified to support arbitrary window sizes,
|
||||
as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
|
||||
"""
|
||||
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
||||
window_area = window_size[0] * window_size[1]
|
||||
grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
|
||||
coords = torch.stack(grid) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
||||
)
|
||||
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
relative_position_index[0, 0:] = num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = num_relative_distance - 2
|
||||
relative_position_index[0, 0] = num_relative_distance - 1
|
||||
return relative_position_index
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
||||
def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
|
||||
"""
|
||||
Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
|
||||
"""
|
||||
old_height = 2 * self.window_size[0] - 1
|
||||
old_width = 2 * self.window_size[1] - 1
|
||||
|
||||
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
new_height = 2 * window_size[0] - 1
|
||||
new_width = 2 * window_size[1] - 1
|
||||
|
||||
old_relative_position_bias_table = self.relative_position_bias_table
|
||||
|
||||
old_num_relative_distance = self.num_relative_distance
|
||||
new_num_relative_distance = new_height * new_width + 3
|
||||
|
||||
old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
|
||||
|
||||
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
|
||||
new_sub_table = nn.functional.interpolate(
|
||||
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
|
||||
)
|
||||
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
|
||||
|
||||
new_relative_position_bias_table = torch.cat(
|
||||
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
|
||||
)
|
||||
|
||||
key = window_size
|
||||
if key not in self.relative_position_indices.keys():
|
||||
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
|
||||
|
||||
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
|
||||
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
|
||||
)
|
||||
# num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
||||
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
if interpolate_pos_encoding:
|
||||
relative_position_bias = nn.functional.interpolate(
|
||||
relative_position_bias.unsqueeze(1),
|
||||
@ -554,7 +594,7 @@ class BeitRelativePositionBias(nn.Module):
|
||||
align_corners=False,
|
||||
).squeeze(1)
|
||||
|
||||
return relative_position_bias
|
||||
return relative_position_bias.unsqueeze(0)
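
For readers following the window-size change outside the diff, here is a self-contained sketch of the index construction that `generate_relative_position_index` performs for an arbitrary window size (a plain re-implementation for illustration, not the library class itself):

```python
import torch

def relative_position_index(window_size):
    # number of distinct relative offsets, plus 3 slots for cls-to-token, token-to-cls and cls-to-cls
    num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
    window_area = window_size[0] * window_size[1]
    grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
    coords = torch.stack(grid)                 # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
    index[1:, 1:] = relative_coords.sum(-1)
    index[0, 0:] = num_relative_distance - 3
    index[0:, 0] = num_relative_distance - 2
    index[0, 0] = num_relative_distance - 1
    return index

print(relative_position_index((2, 3)).shape)  # torch.Size([7, 7]): 2*3 patches plus the CLS token
```

Caching one such index per window size, as the new `relative_position_indices` dictionary does, is what lets a single bias table serve inputs of different resolutions.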
|
||||
|
||||
|
||||
class BeitEncoder(nn.Module):
|
||||
@ -587,6 +627,7 @@ class BeitEncoder(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -606,13 +647,22 @@ class BeitEncoder(nn.Module):
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
relative_position_bias = (
|
||||
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
|
||||
self.relative_position_bias(
|
||||
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
||||
)
|
||||
if self.relative_position_bias is not None
|
||||
else None
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
relative_position_bias,
|
||||
interpolate_pos_encoding,
|
||||
resolution,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -643,6 +693,7 @@ class BeitPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BeitLayer"]
|
||||
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -738,7 +789,7 @@ class BeitModel(BeitPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -756,9 +807,6 @@ class BeitModel(BeitPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@ -766,15 +814,17 @@ class BeitModel(BeitPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(
|
||||
embedding_output, _ = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
resolution = pixel_values.shape[2:]
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
resolution=resolution,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
@ -1477,9 +1527,14 @@ class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
|
||||
|
||||
batch_size = pixel_values.shape[0]
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)
|
||||
resolution = pixel_values.shape[2:]
|
||||
|
||||
outputs = self.encoder(
|
||||
embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
|
||||
embedding_output,
|
||||
output_hidden_states=True,
|
||||
output_attentions=output_attentions,
|
||||
resolution=resolution,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
||||
|
@ -32,7 +32,7 @@ from ...modeling_outputs import (
|
||||
SemanticSegmenterOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
@ -192,12 +192,6 @@ class Data2VecVisionEmbeddings(nn.Module):
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> torch.Tensor:
|
||||
_, _, height, width = pixel_values.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model"
|
||||
f" ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||
)
|
||||
@ -281,6 +275,7 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
|
||||
class Data2VecVisionSelfAttention(nn.Module):
|
||||
def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
||||
@ -314,6 +309,7 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
|
||||
@ -328,9 +324,11 @@ class Data2VecVisionSelfAttention(nn.Module):
|
||||
|
||||
# Add relative position bias if present.
|
||||
if self.relative_position_bias is not None:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
attention_scores = attention_scores + self.relative_position_bias(
|
||||
interpolate_pos_encoding, attention_scores.shape[2]
|
||||
).unsqueeze(0)
|
||||
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
||||
)
|
||||
|
||||
# Add shared relative position bias if provided.
|
||||
if relative_position_bias is not None:
|
||||
@ -410,9 +408,10 @@ class Data2VecVisionAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_outputs = self.attention(
|
||||
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
|
||||
)
|
||||
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
@ -483,6 +482,7 @@ class Data2VecVisionLayer(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self_attention_outputs = self.attention(
|
||||
self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention
|
||||
@ -490,6 +490,7 @@ class Data2VecVisionLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
relative_position_bias=relative_position_bias,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
resolution=resolution,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
@ -529,32 +530,71 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
self.relative_position_indices = {}
|
||||
|
||||
def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
This method creates the relative position index, modified to support arbitrary window sizes,
|
||||
as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
|
||||
"""
|
||||
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
||||
window_area = window_size[0] * window_size[1]
|
||||
grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
|
||||
coords = torch.stack(grid) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
||||
)
|
||||
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
relative_position_index[0, 0:] = num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = num_relative_distance - 2
|
||||
relative_position_index[0, 0] = num_relative_distance - 1
|
||||
return relative_position_index
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
||||
def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
|
||||
"""
|
||||
Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
|
||||
"""
|
||||
old_height = 2 * self.window_size[0] - 1
|
||||
old_width = 2 * self.window_size[1] - 1
|
||||
|
||||
def forward(self, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
new_height = 2 * window_size[0] - 1
|
||||
new_width = 2 * window_size[1] - 1
|
||||
|
||||
old_relative_position_bias_table = self.relative_position_bias_table
|
||||
|
||||
old_num_relative_distance = self.num_relative_distance
|
||||
new_num_relative_distance = new_height * new_width + 3
|
||||
|
||||
old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]
|
||||
|
||||
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
|
||||
new_sub_table = nn.functional.interpolate(
|
||||
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
|
||||
)
|
||||
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
|
||||
|
||||
new_relative_position_bias_table = torch.cat(
|
||||
[new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
|
||||
)
|
||||
|
||||
key = window_size
|
||||
if key not in self.relative_position_indices.keys():
|
||||
self.relative_position_indices[key] = self.generate_relative_position_index(window_size)
|
||||
|
||||
relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
|
||||
# patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
|
||||
relative_position_bias = relative_position_bias.view(
|
||||
window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
|
||||
)
|
||||
# num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
|
||||
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
if interpolate_pos_encoding:
|
||||
relative_position_bias = nn.functional.interpolate(
|
||||
relative_position_bias.unsqueeze(1),
|
||||
@ -563,7 +603,7 @@ class Data2VecVisionRelativePositionBias(nn.Module):
|
||||
align_corners=False,
|
||||
).squeeze(1)
|
||||
|
||||
return relative_position_bias
|
||||
return relative_position_bias.unsqueeze(0)
|
||||
|
||||
|
||||
# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
|
||||
@ -597,6 +637,7 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
resolution: Optional[Tuple[int]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -616,13 +657,22 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
height, width = resolution
|
||||
window_size = (height // self.config.patch_size, width // self.config.patch_size)
|
||||
relative_position_bias = (
|
||||
self.relative_position_bias(interpolate_pos_encoding, hidden_states.shape[1])
|
||||
self.relative_position_bias(
|
||||
window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
|
||||
)
|
||||
if self.relative_position_bias is not None
|
||||
else None
|
||||
)
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, layer_head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
output_attentions,
|
||||
relative_position_bias,
|
||||
interpolate_pos_encoding,
|
||||
resolution,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -654,6 +704,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Data2VecVisionLayer"]
|
||||
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -750,7 +801,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -768,9 +819,6 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@ -778,15 +826,17 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(
|
||||
embedding_output, _ = self.embeddings(
|
||||
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
||||
)
|
||||
resolution = pixel_values.shape[2:]
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
resolution=resolution,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
@ -58,7 +58,7 @@ def get_resize_output_image_size(
|
||||
multiple: int,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
def constraint_to_multiple_of(val, multiple, min_val=0, max_val=None):
|
||||
def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
|
||||
x = round(val / multiple) * multiple
|
||||
|
||||
if max_val is not None and x > max_val:
|
||||
@ -87,8 +87,8 @@ def get_resize_output_image_size(
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
|
||||
new_height = constraint_to_multiple_of(scale_height * input_height, multiple=multiple)
|
||||
new_width = constraint_to_multiple_of(scale_width * input_width, multiple=multiple)
|
||||
new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
|
||||
new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
|
||||
|
||||
return (new_height, new_width)
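
The renamed helper snaps a scaled dimension to the nearest multiple while respecting optional bounds. A hedged, self-contained sketch of that behavior (the floor/ceil fallbacks are assumed from the surrounding DPT code, which is only partially shown here):

```python
import math

def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
    x = round(val / multiple) * multiple
    if max_val is not None and x > max_val:
        x = math.floor(val / multiple) * multiple  # round down if the cap would be exceeded
    if x < min_val:
        x = math.ceil(val / multiple) * multiple   # round up if below the floor
    return x

print(constrain_to_multiple_of(518.4, multiple=32))  # 512
```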
|
||||
|
||||
|
@ -1021,7 +1021,7 @@ class DPTNeck(nn.Module):
|
||||
|
||||
class DPTDepthEstimationHead(nn.Module):
|
||||
"""
|
||||
Output head head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
|
||||
Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
|
||||
the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
|
||||
supplementary material).
|
||||
"""
|
||||
|
67
src/transformers/models/zoedepth/__init__.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_torch_available, is_vision_available
|
||||
from ...utils import OptionalDependencyNotAvailable
|
||||
|
||||
|
||||
_import_structure = {"configuration_zoedepth": ["ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP", "ZoeDepthConfig"]}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_zoedepth"] = [
|
||||
"ZoeDepthForDepthEstimation",
|
||||
"ZoeDepthPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_zoedepth"] = ["ZoeDepthImageProcessor"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_zoedepth import ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP, ZoeDepthConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_zoedepth import (
|
||||
ZoeDepthForDepthEstimation,
|
||||
ZoeDepthPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_zoedepth import ZoeDepthImageProcessor
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
234
src/transformers/models/zoedepth/configuration_zoedepth.py
Normal file
@ -0,0 +1,234 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ZoeDepth model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ZOEDEPTH_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"Intel/zoedepth-nyu": "https://huggingface.co/Intel/zoedepth-nyu/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class ZoeDepthConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`ZoeDepthForDepthEstimation`]. It is used to instantiate an ZoeDepth
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the ZoeDepth
|
||||
[Intel/zoedepth-nyu](https://huggingface.co/Intel/zoedepth-nyu) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*, defaults to `BeitConfig()`):
|
||||
The configuration of the backbone model.
|
||||
backbone (`str`, *optional*):
|
||||
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
|
||||
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
|
||||
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
|
||||
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use pretrained weights for the backbone.
|
||||
backbone_kwargs (`dict`, *optional*):
|
||||
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
|
||||
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the batch normalization layers.
|
||||
readout_type (`str`, *optional*, defaults to `"project"`):
|
||||
The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of
|
||||
the ViT backbone. Can be one of [`"ignore"`, `"add"`, `"project"`].
|
||||
|
||||
- "ignore" simply ignores the CLS token.
|
||||
- "add" passes the information from the CLS token to all other tokens by adding the representations.
|
||||
- "project" passes information to the other tokens by concatenating the readout to all other tokens before
|
||||
projecting the
|
||||
representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
|
||||
reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
|
||||
The up/downsampling factors of the reassemble layers.
|
||||
neck_hidden_sizes (`List[str]`, *optional*, defaults to `[96, 192, 384, 768]`):
|
||||
The hidden sizes to project to for the feature maps of the backbone.
|
||||
fusion_hidden_size (`int`, *optional*, defaults to 256):
|
||||
The number of channels before fusion.
|
||||
head_in_index (`int`, *optional*, defaults to -1):
|
||||
The index of the features to use in the heads.
|
||||
use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
|
||||
use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the pre-activate residual units of the fusion blocks.
|
||||
num_relative_features (`int`, *optional*, defaults to 32):
|
||||
The number of features to use in the relative depth estimation head.
|
||||
add_projection (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a projection layer before the depth estimation head.
|
||||
bottleneck_features (`int`, *optional*, defaults to 256):
|
||||
The number of features in the bottleneck layer.
|
||||
num_attractors (`List[int], *optional*, defaults to `[16, 8, 4, 1]`):
|
||||
The number of attractors to use in each stage.
|
||||
bin_embedding_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of the bin embeddings.
|
||||
attractor_alpha (`int`, *optional*, defaults to 1000):
|
||||
The alpha value to use in the attractor.
|
||||
attractor_gamma (`int`, *optional*, defaults to 2):
|
||||
The gamma value to use in the attractor.
|
||||
attractor_kind (`str`, *optional*, defaults to `"mean"`):
|
||||
The kind of attractor to use. Can be one of [`"mean"`, `"sum"`].
|
||||
min_temp (`float`, *optional*, defaults to 0.0212):
|
||||
The minimum temperature value to consider.
|
||||
max_temp (`float`, *optional*, defaults to 50.0):
|
||||
The maximum temperature value to consider.
|
||||
bin_centers_type (`str`, *optional*, defaults to `"softplus"`):
|
||||
Activation type used for bin centers. Can be "normed" or "softplus". For "normed" bin centers, linear normalization trick
|
||||
is applied. This results in bounded bin centers. For "softplus", softplus activation is used and thus are unbounded.
|
||||
bin_configurations (`List[dict]`, *optional*, defaults to `[{'n_bins': 64, 'min_depth': 0.001, 'max_depth': 10.0}]`):
|
||||
Configuration for each of the bin heads.
|
||||
Each configuration should consist of the following keys:
|
||||
- name (`str`): The name of the bin head - only required in case of multiple bin configurations.
|
||||
- `n_bins` (`int`): The number of bins to use.
|
||||
- `min_depth` (`float`): The minimum depth value to consider.
|
||||
- `max_depth` (`float`): The maximum depth value to consider.
|
||||
In case only a single configuration is passed, the model will use a single head with the specified configuration.
|
||||
In case multiple configurations are passed, the model will use multiple heads with the specified configurations.
|
||||
num_patch_transformer_layers (`int`, *optional*):
|
||||
The number of transformer layers to use in the patch transformer. Only used in case of multiple bin configurations.
|
||||
patch_transformer_hidden_size (`int`, *optional*):
|
||||
The hidden size to use in the patch transformer. Only used in case of multiple bin configurations.
|
||||
patch_transformer_intermediate_size (`int`, *optional*):
|
||||
The intermediate size to use in the patch transformer. Only used in case of multiple bin configurations.
|
||||
patch_transformer_num_attention_heads (`int`, *optional*):
|
||||
The number of attention heads to use in the patch transformer. Only used in case of multiple bin configurations.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import ZoeDepthConfig, ZoeDepthForDepthEstimation
|
||||
|
||||
>>> # Initializing a ZoeDepth zoedepth-large style configuration
|
||||
>>> configuration = ZoeDepthConfig()
|
||||
|
||||
>>> # Initializing a model from the zoedepth-large style configuration
|
||||
>>> model = ZoeDepthForDepthEstimation(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "zoedepth"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone_config=None,
|
||||
backbone=None,
|
||||
use_pretrained_backbone=False,
|
||||
backbone_kwargs=None,
|
||||
hidden_act="gelu",
|
||||
initializer_range=0.02,
|
||||
batch_norm_eps=1e-05,
|
||||
readout_type="project",
|
||||
reassemble_factors=[4, 2, 1, 0.5],
|
||||
neck_hidden_sizes=[96, 192, 384, 768],
|
||||
fusion_hidden_size=256,
|
||||
head_in_index=-1,
|
||||
use_batch_norm_in_fusion_residual=False,
|
||||
use_bias_in_fusion_residual=None,
|
||||
num_relative_features=32,
|
||||
add_projection=False,
|
||||
bottleneck_features=256,
|
||||
num_attractors=[16, 8, 4, 1],
|
||||
bin_embedding_dim=128,
|
||||
attractor_alpha=1000,
|
||||
attractor_gamma=2,
|
||||
attractor_kind="mean",
|
||||
min_temp=0.0212,
|
||||
max_temp=50.0,
|
||||
bin_centers_type="softplus",
|
||||
bin_configurations=[{"n_bins": 64, "min_depth": 0.001, "max_depth": 10.0}],
|
||||
num_patch_transformer_layers=None,
|
||||
patch_transformer_hidden_size=None,
|
||||
patch_transformer_intermediate_size=None,
|
||||
patch_transformer_num_attention_heads=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if readout_type not in ["ignore", "add", "project"]:
|
||||
raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']")
|
||||
|
||||
if attractor_kind not in ["mean", "sum"]:
|
||||
raise ValueError("Attractor_kind must be one of ['mean', 'sum']")
|
||||
|
||||
if use_pretrained_backbone:
|
||||
raise ValueError("Pretrained backbones are not supported yet.")
|
||||
|
||||
if backbone_config is not None and backbone is not None:
|
||||
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
|
||||
|
||||
if backbone_config is None and backbone is None:
|
||||
logger.info("`backbone_config` is `None`. Initializing the config with the default `BEiT` backbone.")
|
||||
backbone_config = CONFIG_MAPPING["beit"](
|
||||
image_size=384,
|
||||
num_hidden_layers=24,
|
||||
hidden_size=1024,
|
||||
intermediate_size=4096,
|
||||
num_attention_heads=16,
|
||||
use_relative_position_bias=True,
|
||||
reshape_hidden_states=False,
|
||||
out_features=["stage6", "stage12", "stage18", "stage24"],
|
||||
)
|
||||
elif isinstance(backbone_config, dict):
|
||||
backbone_model_type = backbone_config.get("model_type")
|
||||
config_class = CONFIG_MAPPING[backbone_model_type]
|
||||
backbone_config = config_class.from_dict(backbone_config)
|
||||
|
||||
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
|
||||
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
|
||||
|
||||
self.backbone_config = backbone_config
|
||||
self.backbone = backbone
|
||||
self.hidden_act = hidden_act
|
||||
self.use_pretrained_backbone = use_pretrained_backbone
|
||||
self.initializer_range = initializer_range
|
||||
self.batch_norm_eps = batch_norm_eps
|
||||
self.readout_type = readout_type
|
||||
self.reassemble_factors = reassemble_factors
|
||||
self.neck_hidden_sizes = neck_hidden_sizes
|
||||
self.fusion_hidden_size = fusion_hidden_size
|
||||
self.head_in_index = head_in_index
|
||||
self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
|
||||
self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
|
||||
self.num_relative_features = num_relative_features
|
||||
self.add_projection = add_projection
|
||||
|
||||
self.bottleneck_features = bottleneck_features
|
||||
self.num_attractors = num_attractors
|
||||
self.bin_embedding_dim = bin_embedding_dim
|
||||
self.attractor_alpha = attractor_alpha
|
||||
self.attractor_gamma = attractor_gamma
|
||||
self.attractor_kind = attractor_kind
|
||||
self.min_temp = min_temp
|
||||
self.max_temp = max_temp
|
||||
self.bin_centers_type = bin_centers_type
|
||||
self.bin_configurations = bin_configurations
|
||||
self.num_patch_transformer_layers = num_patch_transformer_layers
|
||||
self.patch_transformer_hidden_size = patch_transformer_hidden_size
|
||||
self.patch_transformer_intermediate_size = patch_transformer_intermediate_size
|
||||
self.patch_transformer_num_attention_heads = patch_transformer_num_attention_heads
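For orientation, a minimal sketch of how this configuration is meant to be used once the model is released (class names are the ones added in this PR; the defaults above may still change):

    from transformers import ZoeDepthConfig, ZoeDepthForDepthEstimation

    # with no arguments, the config falls back to the BEiT-large backbone defined above
    config = ZoeDepthConfig()
    model = ZoeDepthForDepthEstimation(config)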
src/transformers/models/zoedepth/convert_zoedepth_to_hf.py (new file, 426 lines)
@@ -0,0 +1,426 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ZoeDepth checkpoints from the original repository. URL: https://github.com/isl-org/ZoeDepth.
|
||||
|
||||
Original logits were obtained by running the following code:
|
||||
!git clone -b understanding_zoedepth https://github.com/NielsRogge/ZoeDepth
|
||||
!python inference.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import BeitConfig, ZoeDepthConfig, ZoeDepthForDepthEstimation, ZoeDepthImageProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_zoedepth_config(model_name):
|
||||
image_size = 384
|
||||
backbone_config = BeitConfig(
|
||||
image_size=image_size,
|
||||
num_hidden_layers=24,
|
||||
hidden_size=1024,
|
||||
intermediate_size=4096,
|
||||
num_attention_heads=16,
|
||||
use_relative_position_bias=True,
|
||||
reshape_hidden_states=False,
|
||||
out_features=["stage6", "stage12", "stage18", "stage24"], # beit-large-512 uses [5, 11, 17, 23],
|
||||
)
|
||||
|
||||
neck_hidden_sizes = [256, 512, 1024, 1024]
|
||||
bin_centers_type = "softplus" if model_name in ["ZoeD_N", "ZoeD_NK"] else "normed"
|
||||
if model_name == "ZoeD_NK":
|
||||
bin_configurations = [
|
||||
{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
|
||||
{"name": "kitti", "n_bins": 64, "min_depth": 1e-3, "max_depth": 80.0},
|
||||
]
|
||||
elif model_name in ["ZoeD_N", "ZoeD_K"]:
|
||||
bin_configurations = [
|
||||
{"name": "nyu", "n_bins": 64, "min_depth": 1e-3, "max_depth": 10.0},
|
||||
]
|
||||
config = ZoeDepthConfig(
|
||||
backbone_config=backbone_config,
|
||||
neck_hidden_sizes=neck_hidden_sizes,
|
||||
bin_centers_type=bin_centers_type,
|
||||
bin_configurations=bin_configurations,
|
||||
num_patch_transformer_layers=4 if model_name == "ZoeD_NK" else None,
|
||||
patch_transformer_hidden_size=128 if model_name == "ZoeD_NK" else None,
|
||||
patch_transformer_intermediate_size=1024 if model_name == "ZoeD_NK" else None,
|
||||
patch_transformer_num_attention_heads=4 if model_name == "ZoeD_NK" else None,
|
||||
)
|
||||
|
||||
return config, image_size
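As a quick illustration of the branching above, requesting the NYU+KITTI checkpoint yields a config with two bin heads and a patch transformer (a sketch, assuming this helper is imported from the script):

    config, image_size = get_zoedepth_config("ZoeD_NK")
    assert image_size == 384
    assert len(config.bin_configurations) == 2
    assert config.num_patch_transformer_layers == 4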
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
# Transformer backbone
|
||||
if "core.core.pretrained.model.blocks" in name:
|
||||
name = name.replace("core.core.pretrained.model.blocks", "backbone.encoder.layer")
|
||||
if "core.core.pretrained.model.patch_embed.proj" in name:
|
||||
name = name.replace(
|
||||
"core.core.pretrained.model.patch_embed.proj", "backbone.embeddings.patch_embeddings.projection"
|
||||
)
|
||||
if "core.core.pretrained.model.cls_token" in name:
|
||||
name = name.replace("core.core.pretrained.model.cls_token", "backbone.embeddings.cls_token")
|
||||
if "norm1" in name and "patch_transformer" not in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name and "patch_transformer" not in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
if "gamma_1" in name:
|
||||
name = name.replace("gamma_1", "lambda_1")
|
||||
if "gamma_2" in name:
|
||||
name = name.replace("gamma_2", "lambda_2")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn.relative_position_bias_table" in name:
|
||||
name = name.replace(
|
||||
"attn.relative_position_bias_table",
|
||||
"attention.attention.relative_position_bias.relative_position_bias_table",
|
||||
)
|
||||
if "attn.relative_position_index" in name:
|
||||
name = name.replace(
|
||||
"attn.relative_position_index", "attention.attention.relative_position_bias.relative_position_index"
|
||||
)
|
||||
|
||||
# activation postprocessing (readout projections + resize blocks)
|
||||
if "core.core.pretrained.act_postprocess1.0.project" in name:
|
||||
name = name.replace(
|
||||
"core.core.pretrained.act_postprocess1.0.project", "neck.reassemble_stage.readout_projects.0"
|
||||
)
|
||||
if "core.core.pretrained.act_postprocess2.0.project" in name:
|
||||
name = name.replace(
|
||||
"core.core.pretrained.act_postprocess2.0.project", "neck.reassemble_stage.readout_projects.1"
|
||||
)
|
||||
if "core.core.pretrained.act_postprocess3.0.project" in name:
|
||||
name = name.replace(
|
||||
"core.core.pretrained.act_postprocess3.0.project", "neck.reassemble_stage.readout_projects.2"
|
||||
)
|
||||
if "core.core.pretrained.act_postprocess4.0.project" in name:
|
||||
name = name.replace(
|
||||
"core.core.pretrained.act_postprocess4.0.project", "neck.reassemble_stage.readout_projects.3"
|
||||
)
|
||||
|
||||
if "core.core.pretrained.act_postprocess1.3" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection")
|
||||
if "core.core.pretrained.act_postprocess2.3" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection")
|
||||
if "core.core.pretrained.act_postprocess3.3" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection")
|
||||
if "core.core.pretrained.act_postprocess4.3" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection")
|
||||
|
||||
if "core.core.pretrained.act_postprocess1.4" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize")
|
||||
if "core.core.pretrained.act_postprocess2.4" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize")
|
||||
if "core.core.pretrained.act_postprocess4.4" in name:
|
||||
name = name.replace("core.core.pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize")
|
||||
|
||||
# scratch convolutions
|
||||
if "core.core.scratch.layer1_rn.weight" in name:
|
||||
name = name.replace("core.core.scratch.layer1_rn.weight", "neck.convs.0.weight")
|
||||
if "core.core.scratch.layer2_rn.weight" in name:
|
||||
name = name.replace("core.core.scratch.layer2_rn.weight", "neck.convs.1.weight")
|
||||
if "core.core.scratch.layer3_rn.weight" in name:
|
||||
name = name.replace("core.core.scratch.layer3_rn.weight", "neck.convs.2.weight")
|
||||
if "core.core.scratch.layer4_rn.weight" in name:
|
||||
name = name.replace("core.core.scratch.layer4_rn.weight", "neck.convs.3.weight")
|
||||
|
||||
# fusion layers
|
||||
# tricky here: mapping = {1:3, 2:2, 3:1, 4:0}
|
||||
if "core.core.scratch.refinenet1" in name:
|
||||
name = name.replace("core.core.scratch.refinenet1", "neck.fusion_stage.layers.3")
|
||||
if "core.core.scratch.refinenet2" in name:
|
||||
name = name.replace("core.core.scratch.refinenet2", "neck.fusion_stage.layers.2")
|
||||
if "core.core.scratch.refinenet3" in name:
|
||||
name = name.replace("core.core.scratch.refinenet3", "neck.fusion_stage.layers.1")
|
||||
if "core.core.scratch.refinenet4" in name:
|
||||
name = name.replace("core.core.scratch.refinenet4", "neck.fusion_stage.layers.0")
|
||||
|
||||
if "resConfUnit1" in name:
|
||||
name = name.replace("resConfUnit1", "residual_layer1")
|
||||
|
||||
if "resConfUnit2" in name:
|
||||
name = name.replace("resConfUnit2", "residual_layer2")
|
||||
|
||||
if "conv1" in name:
|
||||
name = name.replace("conv1", "convolution1")
|
||||
|
||||
if "conv2" in name and "residual_layer" in name:
|
||||
name = name.replace("conv2", "convolution2")
|
||||
|
||||
if "out_conv" in name:
|
||||
name = name.replace("out_conv", "projection")
|
||||
|
||||
# relative depth estimation head
|
||||
if "core.core.scratch.output_conv.0" in name:
|
||||
name = name.replace("core.core.scratch.output_conv.0", "relative_head.conv1")
|
||||
|
||||
if "core.core.scratch.output_conv.2" in name:
|
||||
name = name.replace("core.core.scratch.output_conv.2", "relative_head.conv2")
|
||||
|
||||
if "core.core.scratch.output_conv.4" in name:
|
||||
name = name.replace("core.core.scratch.output_conv.4", "relative_head.conv3")
|
||||
|
||||
# patch transformer
|
||||
if "patch_transformer" in name:
|
||||
name = name.replace("patch_transformer", "metric_head.patch_transformer")
|
||||
|
||||
if "mlp_classifier.0" in name:
|
||||
name = name.replace("mlp_classifier.0", "metric_head.mlp_classifier.linear1")
|
||||
if "mlp_classifier.2" in name:
|
||||
name = name.replace("mlp_classifier.2", "metric_head.mlp_classifier.linear2")
|
||||
|
||||
if "projectors" in name:
|
||||
name = name.replace("projectors", "metric_head.projectors")
|
||||
|
||||
if "seed_bin_regressors" in name:
|
||||
name = name.replace("seed_bin_regressors", "metric_head.seed_bin_regressors")
|
||||
|
||||
if "seed_bin_regressor" in name and "seed_bin_regressors" not in name:
|
||||
name = name.replace("seed_bin_regressor", "metric_head.seed_bin_regressor")
|
||||
|
||||
if "seed_projector" in name:
|
||||
name = name.replace("seed_projector", "metric_head.seed_projector")
|
||||
|
||||
if "_net.0" in name:
|
||||
name = name.replace("_net.0", "conv1")
|
||||
|
||||
if "_net.2" in name:
|
||||
name = name.replace("_net.2", "conv2")
|
||||
|
||||
if "attractors" in name:
|
||||
name = name.replace("attractors", "metric_head.attractors")
|
||||
|
||||
if "conditional_log_binomial" in name:
|
||||
name = name.replace("conditional_log_binomial", "metric_head.conditional_log_binomial")
|
||||
|
||||
# metric depth estimation head
|
||||
if "conv2" in name and "metric_head" not in name and "attractors" not in name and "relative_head" not in name:
|
||||
name = name.replace("conv2", "metric_head.conv2")
|
||||
|
||||
if "transformer_encoder.layers" in name:
|
||||
name = name.replace("transformer_encoder.layers", "transformer_encoder")
|
||||
|
||||
return name
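A concrete example of the chained replacements above (the key is hypothetical, but the rules applied are exactly the ones defined in this function):

    old_key = "core.core.pretrained.model.blocks.0.attn.proj.weight"
    print(rename_key(old_key))
    # -> backbone.encoder.layer.0.attention.output.dense.weight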
|
||||
|
||||
|
||||
def read_in_q_k_v_metric_head(state_dict):
|
||||
hidden_size = 128
|
||||
for i in range(4):
|
||||
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"patch_transformer.transformer_encoder.layers.{i}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"patch_transformer.transformer_encoder.layers.{i}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.query.weight"] = in_proj_weight[
|
||||
:hidden_size, :
|
||||
]
|
||||
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.query.bias"] = in_proj_bias[:hidden_size]
|
||||
|
||||
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.key.weight"] = in_proj_weight[
|
||||
hidden_size : hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.key.bias"] = in_proj_bias[
|
||||
hidden_size : hidden_size * 2
|
||||
]
|
||||
|
||||
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.value.weight"] = in_proj_weight[
|
||||
-hidden_size:, :
|
||||
]
|
||||
state_dict[f"patch_transformer.transformer_encoder.{i}.self_attn.value.bias"] = in_proj_bias[-hidden_size:]
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
# rename key
|
||||
new_name = rename_key(key)
|
||||
orig_state_dict[new_name] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def remove_ignore_keys(state_dict):
|
||||
for key, _ in state_dict.copy().items():
|
||||
if (
|
||||
"fc_norm" in key
|
||||
or "relative_position_index" in key
|
||||
or "k_idx" in key
|
||||
or "K_minus_1" in key
|
||||
or "core.core.pretrained.model.head" in key
|
||||
):
|
||||
state_dict.pop(key, None)
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config):
|
||||
hidden_size = config.backbone_config.hidden_size
|
||||
for i in range(config.backbone_config.num_hidden_layers):
|
||||
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"core.core.pretrained.model.blocks.{i}.attn.v_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :]
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
hidden_size : hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :]
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
|
||||
# We will verify our results on an image
|
||||
def prepare_img():
|
||||
filepath = hf_hub_download(repo_id="shariqfarooq/ZoeDepth", filename="examples/person_1.jpeg", repo_type="space")
|
||||
image = Image.open(filepath).convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_zoedepth_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ZoeDepth structure.
|
||||
"""
|
||||
|
||||
# define ZoeDepth configuration based on the model name
|
||||
config, _ = get_zoedepth_config(model_name)
|
||||
|
||||
# load original model
|
||||
original_model = torch.hub.load(
|
||||
"NielsRogge/ZoeDepth:understanding_zoedepth", model_name, pretrained=True, force_reload=True
|
||||
)
|
||||
original_model.eval()
|
||||
state_dict = original_model.state_dict()
|
||||
|
||||
print("Original state dict:")
|
||||
for name, param in state_dict.items():
|
||||
print(name, param.shape)
|
||||
|
||||
# read in qkv matrices
|
||||
read_in_q_k_v(state_dict, config)
|
||||
if model_name == "ZoeD_NK":
|
||||
read_in_q_k_v_metric_head(state_dict)
|
||||
|
||||
# rename keys
|
||||
state_dict = convert_state_dict(state_dict)
|
||||
# remove certain keys
|
||||
remove_ignore_keys(state_dict)
|
||||
|
||||
# load HuggingFace model
|
||||
model = ZoeDepthForDepthEstimation(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# verify image processor
|
||||
image = prepare_img()
|
||||
|
||||
image_processor = ZoeDepthImageProcessor()
|
||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/test-image",
|
||||
filename="zoedepth_pixel_values.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
original_pixel_values = torch.load(filepath, map_location="cpu")
|
||||
assert torch.allclose(pixel_values, original_pixel_values)
|
||||
|
||||
# verify logits
|
||||
# this was done on a resized version of the cats image (384x384)
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/test-image",
|
||||
filename="zoedepth_pixel_values.pt",
|
||||
repo_type="dataset",
|
||||
revision="1865dbb81984f01c89e83eec10f8d07efd10743d",
|
||||
)
|
||||
cats_pixel_values = torch.load(filepath, map_location="cpu")
|
||||
depth = model(cats_pixel_values).predicted_depth
|
||||
|
||||
# Verify logits
|
||||
# These were obtained by inserting the pixel_values at the patch embeddings of BEiT
|
||||
if model_name == "ZoeD_N":
|
||||
expected_shape = torch.Size([1, 384, 384])
|
||||
expected_slice = torch.tensor([[1.0328, 1.0604, 1.0747], [1.0816, 1.1293, 1.1456], [1.1117, 1.1629, 1.1766]])
|
||||
elif model_name == "ZoeD_K":
|
||||
expected_shape = torch.Size([1, 384, 384])
|
||||
expected_slice = torch.tensor([[1.6567, 1.6852, 1.7065], [1.6707, 1.6764, 1.6713], [1.7195, 1.7166, 1.7118]])
|
||||
elif model_name == "ZoeD_NK":
|
||||
expected_shape = torch.Size([1, 384, 384])
|
||||
expected_slice = torch.tensor([[1.1228, 1.1079, 1.1382], [1.1807, 1.1658, 1.1891], [1.2344, 1.2094, 1.2317]])
|
||||
|
||||
print("Shape of depth:", depth.shape)
|
||||
print("First 3x3 slice of depth:", depth[0, :3, :3])
|
||||
|
||||
assert depth.shape == torch.Size(expected_shape)
|
||||
assert torch.allclose(depth[0, :3, :3], expected_slice, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model_name_to_repo_id = {
|
||||
"ZoeD_N": "zoedepth-nyu",
|
||||
"ZoeD_K": "zoedepth-kitti",
|
||||
"ZoeD_NK": "zoedepth-nyu-kitti",
|
||||
}
|
||||
|
||||
print("Pushing model and processor to the hub...")
|
||||
repo_id = model_name_to_repo_id[model_name]
|
||||
model.push_to_hub(f"Intel/{repo_id}")
|
||||
image_processor = ZoeDepthImageProcessor()
|
||||
image_processor.push_to_hub(f"Intel/{repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="ZoeD_N",
|
||||
choices=["ZoeD_N", "ZoeD_K", "ZoeD_NK"],
|
||||
type=str,
|
||||
help="Name of the original ZoeDepth checkpoint you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_zoedepth_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
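Typical invocations of this script, given the arguments defined above (the output path is a placeholder):

    python src/transformers/models/zoedepth/convert_zoedepth_to_hf.py --model_name ZoeD_N --pytorch_dump_folder_path /tmp/zoedepth-nyu
    python src/transformers/models/zoedepth/convert_zoedepth_to_hf.py --model_name ZoeD_NK --push_to_hub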
src/transformers/models/zoedepth/image_processing_zoedepth.py (new file, 454 lines)
@@ -0,0 +1,454 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for ZoeDepth."""
|
||||
|
||||
import math
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray,
|
||||
output_size: Union[int, Iterable[int]],
|
||||
keep_aspect_ratio: bool,
|
||||
multiple: int,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> Tuple[int, int]:
|
||||
def constrain_to_multiple_of(val, multiple, min_val=0):
|
||||
x = (np.round(val / multiple) * multiple).astype(int)
|
||||
|
||||
if x < min_val:
|
||||
x = math.ceil(val / multiple) * multiple
|
||||
|
||||
return x
|
||||
|
||||
output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
|
||||
|
||||
input_height, input_width = get_image_size(input_image, input_data_format)
|
||||
output_height, output_width = output_size
|
||||
|
||||
# determine new height and width
|
||||
scale_height = output_height / input_height
|
||||
scale_width = output_width / input_width
|
||||
|
||||
if keep_aspect_ratio:
|
||||
# scale as little as possible
|
||||
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||
# fit width
|
||||
scale_height = scale_width
|
||||
else:
|
||||
# fit height
|
||||
scale_width = scale_height
|
||||
|
||||
new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
|
||||
new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
|
||||
|
||||
return (new_height, new_width)
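A worked example of the logic above, assuming a 480x640 input, a 512x512 target, keep_aspect_ratio=True and multiple=32: the height scale (512/480 ≈ 1.067) is closer to 1 than the width scale (512/640 = 0.8), so it is used for both dimensions, and both results are rounded to multiples of 32:

    import numpy as np

    scale = 512 / 480                                   # chosen because |1 - 512/480| < |1 - 512/640|
    new_height = int(np.round(scale * 480 / 32) * 32)   # 512
    new_width = int(np.round(scale * 640 / 32) * 32)    # 672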
|
||||
|
||||
|
||||
class ZoeDepthImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a ZoeDepth image processor.
|
||||
|
||||
Args:
|
||||
do_pad (`bool`, *optional*, defaults to `True`):
|
||||
Whether to apply padding to the input.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
`preprocess`.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 512}`):
|
||||
Size of the image after resizing. If `keep_aspect_ratio` is `True`,
|
||||
the image is resized by choosing the smaller of the height and width scaling factors and using it for both dimensions.
|
||||
If `ensure_multiple_of` is also set, the image is further resized to a size that is a multiple of this value.
|
||||
Can be overridden by `size` in `preprocess`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
|
||||
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
|
||||
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
|
||||
multiple of this value by rounding the height and width to the nearest multiple of this value.
|
||||
Can be overridden by `keep_aspect_ratio` in `preprocess`.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to 32):
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by rounding
the height and width to the nearest multiple of this value.
|
||||
|
||||
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overridden by `ensure_multiple_of`
|
||||
in `preprocess`.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_pad: bool = True,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
keep_aspect_ratio: bool = True,
|
||||
ensure_multiple_of: int = 32,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_pad = do_pad
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
||||
size = size if size is not None else {"height": 384, "width": 512}
|
||||
size = get_size_dict(size)
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.keep_aspect_ratio = keep_aspect_ratio
|
||||
self.ensure_multiple_of = ensure_multiple_of
|
||||
self.resample = resample
|
||||
|
||||
self._valid_processor_keys = [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"keep_aspect_ratio",
|
||||
"ensure_multiple_of",
|
||||
"resample",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format",
|
||||
]
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
keep_aspect_ratio: bool = False,
|
||||
ensure_multiple_of: int = 1,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
|
||||
is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
|
||||
set, the image is resized to a size that is a multiple of this value.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Target size of the output image.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to 1):
|
||||
The image is resized to a size that is a multiple of this value.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size
|
||||
specified in `size`.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
"""
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
data_format = data_format if data_format is not None else input_data_format
|
||||
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
|
||||
|
||||
output_size = get_resize_output_image_size(
|
||||
image,
|
||||
output_size=(size["height"], size["width"]),
|
||||
keep_aspect_ratio=keep_aspect_ratio,
|
||||
multiple=ensure_multiple_of,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
height, width = output_size
|
||||
|
||||
torch_image = torch.from_numpy(image).unsqueeze(0)
|
||||
torch_image = torch_image.permute(0, 3, 1, 2) if input_data_format == "channels_last" else torch_image
|
||||
|
||||
# TODO support align_corners=True in image_transforms.resize
|
||||
requires_backends(self, "torch")
|
||||
resample_to_mode = {PILImageResampling.BILINEAR: "bilinear", PILImageResampling.BICUBIC: "bicubic"}
|
||||
mode = resample_to_mode[resample]
|
||||
resized_image = nn.functional.interpolate(
|
||||
torch_image, (int(height), int(width)), mode=mode, align_corners=True
|
||||
)
|
||||
resized_image = resized_image.squeeze().numpy()
|
||||
|
||||
resized_image = to_channel_dimension_format(
|
||||
resized_image, data_format, input_channel_dim=ChannelDimension.FIRST
|
||||
)
|
||||
|
||||
return resized_image
|
||||
|
||||
def pad_image(
|
||||
self,
|
||||
image: np.array,
|
||||
mode: PaddingMode = PaddingMode.REFLECT,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
"""
|
||||
Pad an image as done in the original ZoeDepth implementation.
|
||||
|
||||
Padding fixes the boundary artifacts in the output depth map.
|
||||
Boundary artifacts are sometimes caused by the fact that the model is trained on the NYU raw dataset,
|
||||
which has a black or white border around the image. This function pads the input image and crops
|
||||
the prediction back to the original size / view.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to pad.
|
||||
mode (`PaddingMode`):
|
||||
The padding mode to use. Can be one of:
|
||||
- `"constant"`: pads with a constant value.
|
||||
- `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
|
||||
vector along each axis.
|
||||
- `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
|
||||
- `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
height, width = get_image_size(image, input_data_format)
|
||||
|
||||
pad_height = int(np.sqrt(height / 2) * 3)
|
||||
pad_width = int(np.sqrt(width / 2) * 3)
|
||||
|
||||
return pad(
|
||||
image,
|
||||
padding=((pad_height, pad_height), (pad_width, pad_width)),
|
||||
mode=mode,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
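The padding amount grows with the square root of the image size: for a 480x640 input, for example, pad_height = int(sqrt(480 / 2) * 3) = 46 and pad_width = int(sqrt(640 / 2) * 3) = 53, applied symmetrically on each side before resizing.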
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_pad: bool = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_resize: bool = None,
|
||||
size: int = None,
|
||||
keep_aspect_ratio: bool = None,
|
||||
ensure_multiple_of: int = None,
|
||||
resample: PILImageResampling = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
|
||||
Whether to pad the input image.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized by choosing the smaller of
|
||||
the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of` is also set,
|
||||
the image is further resized to a size that is a multiple of this value.
|
||||
keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
|
||||
If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width scaling factors and using it for
|
||||
both dimensions. This ensures that the image is scaled down as little as possible while still fitting within the
|
||||
desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a size that is a
|
||||
multiple of this value by rounding the height and width to the nearest multiple of this value.
|
||||
ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
|
||||
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by rounding
the height and width to the nearest multiple of this value.
|
||||
|
||||
Works both with and without `keep_aspect_ratio` being set to `True`. Can be overridden by `ensure_multiple_of` in `preprocess`.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size)
|
||||
keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
|
||||
ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_pad:
|
||||
images = [self.pad_image(image=image, input_data_format=input_data_format) for image in images]
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(
|
||||
image=image,
|
||||
size=size,
|
||||
resample=resample,
|
||||
keep_aspect_ratio=keep_aspect_ratio,
|
||||
ensure_multiple_of=ensure_multiple_of,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
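End to end, the processor is used like the other image processors in the library; a minimal sketch with the defaults defined above (the dummy image size is arbitrary):

    from PIL import Image
    from transformers import ZoeDepthImageProcessor

    processor = ZoeDepthImageProcessor()  # rescale -> pad -> resize -> normalize, as in preprocess()
    image = Image.new("RGB", (640, 480))
    pixel_values = processor(image, return_tensors="pt").pixel_values
    print(pixel_values.shape)  # (1, 3, 384, 512) for this input, with both spatial dims multiples of 32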
src/transformers/models/zoedepth/modeling_zoedepth.py (new file, 1403 lines)
(file diff suppressed because it is too large)
src/transformers/utils/dummy_pt_objects.py
@@ -9660,6 +9660,20 @@ class YosoPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ZoeDepthForDepthEstimation(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ZoeDepthPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Adafactor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
src/transformers/utils/dummy_vision_objects.py
@@ -651,3 +651,10 @@ class YolosImageProcessor(metaclass=DummyObject):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class ZoeDepthImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
tests/models/zoedepth/__init__.py (new empty file)
tests/models/zoedepth/test_image_processing_zoedepth.py (new file, 187 lines)
@@ -0,0 +1,187 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.file_utils import is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import ZoeDepthImageProcessor
|
||||
|
||||
|
||||
class ZoeDepthImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
ensure_multiple_of=32,
|
||||
keep_aspect_ratio=False,
|
||||
do_normalize=True,
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
do_pad=False,
|
||||
):
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.ensure_multiple_of = ensure_multiple_of
|
||||
self.keep_aspect_ratio = keep_aspect_ratio
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_pad = do_pad
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"ensure_multiple_of": self.ensure_multiple_of,
|
||||
"keep_aspect_ratio": self.keep_aspect_ratio,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_pad": self.do_pad,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.ensure_multiple_of, self.ensure_multiple_of
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class ZoeDepthImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = ZoeDepthImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
self.image_processor_tester = ZoeDepthImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "ensure_multiple_of"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_pad"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
def test_ensure_multiple_of(self):
|
||||
# Test `ensure_multiple_of` by turning off all other variables which affect the size, using a size which is not a multiple of 32
|
||||
image = np.zeros((489, 640, 3))
|
||||
|
||||
size = {"height": 380, "width": 513}
|
||||
multiple = 32
|
||||
image_processor = ZoeDepthImageProcessor(
|
||||
do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
|
||||
)
|
||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
self.assertEqual(list(pixel_values.shape), [1, 3, 384, 512])
|
||||
self.assertTrue(pixel_values.shape[2] % multiple == 0)
|
||||
self.assertTrue(pixel_values.shape[3] % multiple == 0)
|
||||
|
||||
# Test `ensure_multiple_of` by turning off all other variables which affect the size, using a size which is already a multiple of 32
|
||||
image = np.zeros((511, 511, 3))
|
||||
|
||||
height, width = 512, 512
|
||||
size = {"height": height, "width": width}
|
||||
multiple = 32
|
||||
image_processor = ZoeDepthImageProcessor(
|
||||
do_pad=False, ensure_multiple_of=multiple, size=size, keep_aspect_ratio=False
|
||||
)
|
||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
self.assertEqual(list(pixel_values.shape), [1, 3, height, width])
|
||||
self.assertTrue(pixel_values.shape[2] % multiple == 0)
|
||||
self.assertTrue(pixel_values.shape[3] % multiple == 0)
|
||||
|
||||
def test_keep_aspect_ratio(self):
|
||||
# Test `keep_aspect_ratio=True` by turning off all other variables which affect the size
|
||||
height, width = 489, 640
|
||||
image = np.zeros((height, width, 3))
|
||||
|
||||
size = {"height": 512, "width": 512}
|
||||
image_processor = ZoeDepthImageProcessor(do_pad=False, keep_aspect_ratio=True, size=size, ensure_multiple_of=1)
|
||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# As can be seen, the image is resized to the maximum size that fits in the specified size
|
||||
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 670])
|
||||
|
||||
# Test `keep_aspect_ratio=False` by turning off all other variables which affect the size
|
||||
image_processor = ZoeDepthImageProcessor(
|
||||
do_pad=False, keep_aspect_ratio=False, size=size, ensure_multiple_of=1
|
||||
)
|
||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# As can be seen, the size is respected
|
||||
self.assertEqual(list(pixel_values.shape), [1, 3, size["height"], size["width"]])
|
||||
|
||||
# Test `keep_aspect_ratio=True` with `ensure_multiple_of` set
|
||||
image = np.zeros((489, 640, 3))
|
||||
|
||||
size = {"height": 511, "width": 511}
|
||||
multiple = 32
|
||||
image_processor = ZoeDepthImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=multiple)
|
||||
|
||||
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
|
||||
self.assertTrue(pixel_values.shape[2] % multiple == 0)
|
||||
self.assertTrue(pixel_values.shape[3] % multiple == 0)
|
tests/models/zoedepth/test_modeling_zoedepth.py (new file, 257 lines)
@@ -0,0 +1,257 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch ZoeDepth model."""
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import Dinov2Config, ZoeDepthConfig
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import ZoeDepthForDepthEstimation
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ZoeDepthImageProcessor
|
||||
|
||||
|
||||
class ZoeDepthModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
num_channels=3,
|
||||
image_size=32,
|
||||
patch_size=16,
|
||||
use_labels=True,
|
||||
num_labels=3,
|
||||
is_training=True,
|
||||
hidden_size=4,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
intermediate_size=8,
|
||||
out_features=["stage1", "stage2"],
|
||||
apply_layernorm=False,
|
||||
reshape_hidden_states=False,
|
||||
neck_hidden_sizes=[2, 2],
|
||||
fusion_hidden_size=6,
|
||||
bottleneck_features=6,
|
||||
num_out_features=[6, 6, 6, 6],
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.out_features = out_features
|
||||
self.apply_layernorm = apply_layernorm
|
||||
self.reshape_hidden_states = reshape_hidden_states
|
||||
self.use_labels = use_labels
|
||||
self.num_labels = num_labels
|
||||
self.is_training = is_training
|
||||
self.neck_hidden_sizes = neck_hidden_sizes
|
||||
self.fusion_hidden_size = fusion_hidden_size
|
||||
self.bottleneck_features = bottleneck_features
|
||||
self.num_out_features = num_out_features
|
||||
# ZoeDepth's sequence length
|
||||
self.seq_length = (self.image_size // self.patch_size) ** 2 + 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return ZoeDepthConfig(
|
||||
backbone_config=self.get_backbone_config(),
|
||||
backbone=None,
|
||||
neck_hidden_sizes=self.neck_hidden_sizes,
|
||||
fusion_hidden_size=self.fusion_hidden_size,
|
||||
bottleneck_features=self.bottleneck_features,
|
||||
num_out_features=self.num_out_features,
|
||||
)
|
||||
|
||||
def get_backbone_config(self):
|
||||
return Dinov2Config(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
is_training=self.is_training,
|
||||
out_features=self.out_features,
|
||||
reshape_hidden_states=self.reshape_hidden_states,
|
||||
)
|
||||
|
||||
def create_and_check_for_depth_estimation(self, config, pixel_values, labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = ZoeDepthForDepthEstimation(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class ZoeDepthModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as ZoeDepth does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (ZoeDepthForDepthEstimation,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"depth-estimation": ZoeDepthForDepthEstimation} if is_torch_available() else {}
|
||||
|
||||

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ZoeDepthModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=ZoeDepthConfig, has_text_modality=False, hidden_size=37, common_properties=[]
        )

    def test_config(self):
        self.config_tester.run_common_tests()
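
    # The skips below fall into two groups: ZoeDepth is built on top of AutoBackbone and has no standalone
    # base model (hence no input embeddings to get or set), and training is not supported yet.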
    @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model and hence no input_embeddings")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model and hence no input_embeddings")
    def test_model_get_set_embeddings(self):
        pass

    def test_for_depth_estimation(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)

    @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model and hence no input_embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="ZoeDepth with AutoBackbone does not have a base model")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="ZoeDepth does not support training yet")
    def test_training(self):
        pass

    @unittest.skip(reason="ZoeDepth does not support training yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="ZoeDepth does not support training yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="ZoeDepth does not support training yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "Intel/zoedepth-nyu"
        model = ZoeDepthForDepthEstimation.from_pretrained(model_name)
        self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_vision
@slow
class ZoeDepthModelIntegrationTest(unittest.TestCase):
    def test_inference_depth_estimation(self):
        image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu")
        model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu").to(torch_device)

        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)
            predicted_depth = outputs.predicted_depth

        # verify the predicted depth
        expected_shape = torch.Size((1, 384, 512))
        self.assertEqual(predicted_depth.shape, expected_shape)
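        # (384, 512) is the resolution after preprocessing, not the original image size. Outside of this
        # test, the prediction could be resized back to the input resolution with, for example (an
        # illustrative sketch, not part of the assertions):
        #   depth = torch.nn.functional.interpolate(
        #       predicted_depth.unsqueeze(1), size=image.size[::-1], mode="bicubic", align_corners=False
        #   ).squeeze()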

        expected_slice = torch.tensor(
            [[1.0020, 1.0219, 1.0389], [1.0349, 1.0816, 1.1000], [1.0576, 1.1094, 1.1249]],
        ).to(torch_device)

        self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
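
    # "Intel/zoedepth-nyu-kitti" corresponds to the ZoeD_NK variant, which carries multiple metric-depth heads
    # (one bin configuration per training dataset, NYU and KITTI), hence the "multiple_heads" test name.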

    def test_inference_depth_estimation_multiple_heads(self):
        image_processor = ZoeDepthImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti")
        model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(torch_device)

        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)
            predicted_depth = outputs.predicted_depth

        # verify the predicted depth
        expected_shape = torch.Size((1, 384, 512))
        self.assertEqual(predicted_depth.shape, expected_shape)

        expected_slice = torch.tensor(
            [[1.1571, 1.1438, 1.1783], [1.2163, 1.2036, 1.2320], [1.2688, 1.2461, 1.2734]],
        ).to(torch_device)

        self.assertTrue(torch.allclose(outputs.predicted_depth[0, :3, :3], expected_slice, atol=1e-4))