Add Prompt Depth Anything Model (#35401)

* add prompt depth anything model via modular transformers

* add prompt depth anything docs and imports

* update code style according to the transformers doc

* update code style: import order issue is fixed by custom_init_isort

* fix depth shape from B,1,H,W to B,H,W, which is the same as Depth Anything

* move prompt depth anything to vision models in _toctree.yml

* update backbone test; there is no need for resnet18 backbone test

* update init file & pass RUN_SLOW tests

* update len(prompt_depth) to prompt_depth.shape[0]

Co-authored-by: Joshua Lochner <admin@xenova.com>

* fix torch_int/model_doc

* fix typo

* update PromptDepthAnythingImageProcessor

* fix typo

* fix typo for prompt depth anything doc

* update promptda overview image link of huggingface repo

* fix some typos in promptda doc

* Update image processing to include pad_image, prompt depth position, and related explanations for better clarity and functionality.

* add copy disclaimer for prompt depth anything image processing

* fix some format typos in image processing and conversion scripts

* fix nn.ReLU(False) to nn.ReLU()

* rename residual layer as it's a sequential layer

* move size compute to a separate line/variable for easier debug in modular prompt depth anything

* fix modular format for prompt depth anything

* update modular prompt depth anything

* fix scale to meter and wrap some internal funcs

* fix code style in image_processing_prompt_depth_anything.py

* fix issues in image_processing_prompt_depth_anything.py

* fix issues in image_processing_prompt_depth_anything.py

* fix issues in prompt depth anything

* update conversion script to be similar to mllama

* update testing for modeling prompt depth anything

* update testing for image_processing_prompt_depth_anything

* fix assertion in image_processing_prompt_depth_anything

* Update src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update docs/source/en/model_doc/prompt_depth_anything.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update docs/source/en/model_doc/prompt_depth_anything.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* update some testing

* fix testing

* fix

* add return doc for forward of prompt depth anything

* Update src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update tests/models/prompt_depth_anything/test_modeling_prompt_depth_anything.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* fix prompt depth order

* fix format for testing prompt depth anything

* fix minor issues in prompt depth anything doc

* fix format for modular prompt depth anything

* revert format for modular prompt depth anything

* revert format for modular prompt depth anything

* update format for modular prompt depth anything

* fix parallel testing errors

* fix doc for prompt depth anything

* Add header

* Fix imports

* Licence header

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Haotong LIN 2025-03-21 00:12:44 +08:00 committed by GitHub
parent 66291778dd
commit 6515c25953
18 changed files with 2537 additions and 0 deletions


@@ -735,6 +735,8 @@
title: NAT
- local: model_doc/poolformer
title: PoolFormer
- local: model_doc/prompt_depth_anything
title: Prompt Depth Anything
- local: model_doc/pvt
title: Pyramid Vision Transformer (PVT)
- local: model_doc/pvt_v2


@@ -0,0 +1,96 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Prompt Depth Anything
## Overview
The Prompt Depth Anything model was introduced in [Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation](https://arxiv.org/abs/2412.14015) by Haotong Lin, Sida Peng, Jingxiao Chen, Songyou Peng, Jiaming Sun, Minghuan Liu, Hujun Bao, Jiashi Feng, Xiaowei Zhou, Bingyi Kang.
The abstract from the paper is as follows:
*Prompts play a critical role in unleashing the power of language and vision foundation models for specific tasks. For the first time, we introduce prompting into depth foundation models, creating a new paradigm for metric depth estimation termed Prompt Depth Anything. Specifically, we use a low-cost LiDAR as the prompt to guide the Depth Anything model for accurate metric depth output, achieving up to 4K resolution. Our approach centers on a concise prompt fusion design that integrates the LiDAR at multiple scales within the depth decoder. To address training challenges posed by limited datasets containing both LiDAR depth and precise GT depth, we propose a scalable data pipeline that includes synthetic data LiDAR simulation and real data pseudo GT depth generation. Our approach sets new state-of-the-arts on the ARKitScenes and ScanNet++ datasets and benefits downstream applications, including 3D reconstruction and generalized robotic grasping.*
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/prompt_depth_anything_architecture.jpg"
alt="drawing" width="600"/>
<small> Prompt Depth Anything overview. Taken from the <a href="https://arxiv.org/pdf/2412.14015">original paper</a>.</small>
## Usage example
The Transformers library allows you to use the model with just a few lines of code:
```python
>>> import torch
>>> import requests
>>> import numpy as np
>>> from PIL import Image
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
>>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
>>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
>>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
>>> # the prompt depth can be None, and the model will output a monocular relative depth.
>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # interpolate to original size
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 1000
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint16")) # mm
```
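If no prompt depth is available, you can simply omit it and the model falls back to monocular relative depth estimation. A minimal sketch, reusing the objects created above:
```python
>>> # no prompt depth: the model outputs monocular relative depth
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> relative_depth = image_processor.post_process_depth_estimation(
...     outputs,
...     target_sizes=[(image.height, image.width)],
... )[0]["predicted_depth"]
```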
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Prompt Depth Anything.
- [Prompt Depth Anything Demo](https://huggingface.co/spaces/depth-anything/PromptDA)
- [Prompt Depth Anything Interactive Results](https://promptda.github.io/interactive.html)
If you are interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
## PromptDepthAnythingConfig
[[autodoc]] PromptDepthAnythingConfig
## PromptDepthAnythingForDepthEstimation
[[autodoc]] PromptDepthAnythingForDepthEstimation
- forward
## PromptDepthAnythingImageProcessor
[[autodoc]] PromptDepthAnythingImageProcessor
- preprocess
- post_process_depth_estimation


@@ -711,6 +711,7 @@ _import_structure = {
"models.plbart": ["PLBartConfig"],
"models.poolformer": ["PoolFormerConfig"],
"models.pop2piano": ["Pop2PianoConfig"],
"models.prompt_depth_anything": ["PromptDepthAnythingConfig"],
"models.prophetnet": [
"ProphetNetConfig",
"ProphetNetTokenizer",
@@ -1299,6 +1300,7 @@ else:
_import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"])
_import_structure["models.pixtral"].append("PixtralImageProcessor")
_import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"])
_import_structure["models.prompt_depth_anything"].extend(["PromptDepthAnythingImageProcessor"])
_import_structure["models.pvt"].extend(["PvtImageProcessor"])
_import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"])
_import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"])
@@ -3335,6 +3337,12 @@ else:
"Pop2PianoPreTrainedModel",
]
)
_import_structure["models.prompt_depth_anything"].extend(
[
"PromptDepthAnythingForDepthEstimation",
"PromptDepthAnythingPreTrainedModel",
]
)
_import_structure["models.prophetnet"].extend(
[
"ProphetNetDecoder",
@@ -5921,6 +5929,7 @@ if TYPE_CHECKING:
from .models.pop2piano import (
Pop2PianoConfig,
)
from .models.prompt_depth_anything import PromptDepthAnythingConfig
from .models.prophetnet import (
ProphetNetConfig,
ProphetNetTokenizer,
@@ -6530,6 +6539,7 @@ if TYPE_CHECKING:
PoolFormerFeatureExtractor,
PoolFormerImageProcessor,
)
from .models.prompt_depth_anything import PromptDepthAnythingImageProcessor
from .models.pvt import PvtImageProcessor
from .models.qwen2_vl import Qwen2VLImageProcessor
from .models.rt_detr import RTDetrImageProcessor
@@ -8166,6 +8176,10 @@ if TYPE_CHECKING:
Pop2PianoForConditionalGeneration,
Pop2PianoPreTrainedModel,
)
from .models.prompt_depth_anything import (
PromptDepthAnythingForDepthEstimation,
PromptDepthAnythingPreTrainedModel,
)
from .models.prophetnet import (
ProphetNetDecoder,
ProphetNetEncoder,


@@ -219,6 +219,7 @@ from . import (
plbart,
poolformer,
pop2piano,
prompt_depth_anything,
prophetnet,
pvt,
pvt_v2,


@@ -241,6 +241,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"),
("prompt_depth_anything", "PromptDepthAnythingConfig"),
("prophetnet", "ProphetNetConfig"),
("pvt", "PvtConfig"),
("pvt_v2", "PvtV2Config"),
@@ -593,6 +594,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"),
("prompt_depth_anything", "PromptDepthAnything"),
("prophetnet", "ProphetNet"),
("pvt", "PVT"),
("pvt_v2", "PVTv2"),


@@ -127,6 +127,7 @@ else:
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor",)),
("prompt_depth_anything", ("PromptDepthAnythingImageProcessor",)),
("pvt", ("PvtImageProcessor",)),
("pvt_v2", ("PvtImageProcessor",)),
("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),


@@ -942,6 +942,7 @@ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
("depth_pro", "DepthProForDepthEstimation"),
("dpt", "DPTForDepthEstimation"),
("glpn", "GLPNForDepthEstimation"),
("prompt_depth_anything", "PromptDepthAnythingForDepthEstimation"),
("zoedepth", "ZoeDepthForDepthEstimation"),
]
)


@@ -0,0 +1,31 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_prompt_depth_anything import PromptDepthAnythingConfig
from .image_processing_prompt_depth_anything import PromptDepthAnythingImageProcessor
from .modeling_prompt_depth_anything import (
PromptDepthAnythingForDepthEstimation,
PromptDepthAnythingPreTrainedModel,
)
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@@ -0,0 +1,171 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_prompt_depth_anything.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class PromptDepthAnythingConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PromptDepthAnythingForDepthEstimation`]. It is used to instantiate a PromptDepthAnything
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the PromptDepthAnything
[depth-anything/prompt-depth-anything-vits-hf](https://huggingface.co/depth-anything/prompt-depth-anything-vits-hf) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
The configuration of the backbone model. If both `backbone_config` and `backbone` are `None`, a default
Dinov2 configuration is used.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
patch_size (`int`, *optional*, defaults to 14):
The size of the patches to extract from the backbone features.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
reassemble_hidden_size (`int`, *optional*, defaults to 384):
The number of input channels of the reassemble layers.
reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
The up/downsampling factors of the reassemble layers.
neck_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192, 384]`):
The hidden sizes to project to for the feature maps of the backbone.
fusion_hidden_size (`int`, *optional*, defaults to 64):
The number of channels before fusion.
head_in_index (`int`, *optional*, defaults to -1):
The index of the features to use in the depth estimation head.
head_hidden_size (`int`, *optional*, defaults to 32):
The number of output channels in the second convolution of the depth estimation head.
depth_estimation_type (`str`, *optional*, defaults to `"relative"`):
The type of depth estimation to use. Can be one of `["relative", "metric"]`.
max_depth (`float`, *optional*):
The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models
and 80 for outdoor models. For "relative" depth estimation, this value is ignored.
Example:
```python
>>> from transformers import PromptDepthAnythingConfig, PromptDepthAnythingForDepthEstimation
>>> # Initializing a PromptDepthAnything small style configuration
>>> configuration = PromptDepthAnythingConfig()
>>> # Initializing a model from the PromptDepthAnything small style configuration
>>> model = PromptDepthAnythingForDepthEstimation(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
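# Example (illustrative): a metric-depth variant could be instantiated as
# PromptDepthAnythingConfig(depth_estimation_type="metric", max_depth=20),
# using ~20 for indoor and ~80 for outdoor models as suggested for `max_depth` above.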
model_type = "prompt_depth_anything"
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
patch_size=14,
initializer_range=0.02,
reassemble_hidden_size=384,
reassemble_factors=[4, 2, 1, 0.5],
neck_hidden_sizes=[48, 96, 192, 384],
fusion_hidden_size=64,
head_in_index=-1,
head_hidden_size=32,
depth_estimation_type="relative",
max_depth=None,
**kwargs,
):
super().__init__(**kwargs)
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.")
backbone_config = CONFIG_MAPPING["dinov2"](
image_size=518,
hidden_size=384,
num_attention_heads=6,
out_indices=[9, 10, 11, 12],
apply_layernorm=True,
reshape_hidden_states=False,
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.reassemble_hidden_size = reassemble_hidden_size
self.patch_size = patch_size
self.initializer_range = initializer_range
self.reassemble_factors = reassemble_factors
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
self.head_in_index = head_in_index
self.head_hidden_size = head_hidden_size
if depth_estimation_type not in ["relative", "metric"]:
raise ValueError("depth_estimation_type must be one of ['relative', 'metric']")
self.depth_estimation_type = depth_estimation_type
self.max_depth = max_depth if max_depth else 1
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
if output["backbone_config"] is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
__all__ = ["PromptDepthAnythingConfig"]


@@ -0,0 +1,292 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Prompt Depth Anything checkpoints from the original repository. URL:
https://github.com/DepthAnything/PromptDA"""
import argparse
import re
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
Dinov2Config,
PromptDepthAnythingConfig,
PromptDepthAnythingForDepthEstimation,
PromptDepthAnythingImageProcessor,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_dpt_config(model_name):
if "small" in model_name or "vits" in model_name:
out_indices = [3, 6, 9, 12]
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-small", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
)
fusion_hidden_size = 64
neck_hidden_sizes = [48, 96, 192, 384]
elif "base" in model_name or "vitb" in model_name:
out_indices = [3, 6, 9, 12]
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-base", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
)
fusion_hidden_size = 128
neck_hidden_sizes = [96, 192, 384, 768]
elif "large" in model_name or "vitl" in model_name:
out_indices = [5, 12, 18, 24]
backbone_config = Dinov2Config.from_pretrained(
"facebook/dinov2-large", out_indices=out_indices, apply_layernorm=True, reshape_hidden_states=False
)
fusion_hidden_size = 256
neck_hidden_sizes = [256, 512, 1024, 1024]
else:
raise NotImplementedError(f"Model not supported: {model_name}")
depth_estimation_type = "metric"
max_depth = None
config = PromptDepthAnythingConfig(
reassemble_hidden_size=backbone_config.hidden_size,
patch_size=backbone_config.patch_size,
backbone_config=backbone_config,
fusion_hidden_size=fusion_hidden_size,
neck_hidden_sizes=neck_hidden_sizes,
depth_estimation_type=depth_estimation_type,
max_depth=max_depth,
)
return config
def transform_qkv_weights(key, value, config):
if not key.startswith("qkv_transform"):
return value
layer_idx = int(key.split("_")[-1])
hidden_size = config.backbone_config.hidden_size
suffix = "bias" if "bias" in key else "weight"
return {
f"backbone.encoder.layer.{layer_idx}.attention.attention.query.{suffix}": value[:hidden_size],
f"backbone.encoder.layer.{layer_idx}.attention.attention.key.{suffix}": value[hidden_size : hidden_size * 2],
f"backbone.encoder.layer.{layer_idx}.attention.attention.value.{suffix}": value[-hidden_size:],
}
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Stem
r"pretrained.cls_token": r"backbone.embeddings.cls_token",
r"pretrained.mask_token": r"backbone.embeddings.mask_token",
r"pretrained.pos_embed": r"backbone.embeddings.position_embeddings",
r"pretrained.patch_embed.proj.(weight|bias)": r"backbone.embeddings.patch_embeddings.projection.\1",
# Backbone
r"pretrained.norm.(weight|bias)": r"backbone.layernorm.\1",
# Transformer layers
r"pretrained.blocks.(\d+).ls1.gamma": r"backbone.encoder.layer.\1.layer_scale1.lambda1",
r"pretrained.blocks.(\d+).ls2.gamma": r"backbone.encoder.layer.\1.layer_scale2.lambda1",
r"pretrained.blocks.(\d+).norm1.(weight|bias)": r"backbone.encoder.layer.\1.norm1.\2",
r"pretrained.blocks.(\d+).norm2.(weight|bias)": r"backbone.encoder.layer.\1.norm2.\2",
r"pretrained.blocks.(\d+).mlp.fc1.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc1.\2",
r"pretrained.blocks.(\d+).mlp.fc2.(weight|bias)": r"backbone.encoder.layer.\1.mlp.fc2.\2",
r"pretrained.blocks.(\d+).attn.proj.(weight|bias)": r"backbone.encoder.layer.\1.attention.output.dense.\2",
r"pretrained.blocks.(\d+).attn.qkv.(weight|bias)": r"qkv_transform_\2_\1",
# Neck
r"depth_head.projects.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.projection.\2",
r"depth_head.scratch.layer(\d+)_rn.weight": lambda m: f"neck.convs.{int(m.group(1))-1}.weight",
r"depth_head.resize_layers.(\d+).(weight|bias)": r"neck.reassemble_stage.layers.\1.resize.\2",
# Refinenet (with reversed indices)
r"depth_head.scratch.refinenet(\d+).out_conv.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.projection.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer1.convolution1.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit1.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer1.convolution2.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv1.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer2.convolution1.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit2.conv2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.residual_layer2.convolution2.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.0.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution1.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.2.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution2.{m.group(2)}",
r"depth_head.scratch.refinenet(\d+).resConfUnit_depth.4.(weight|bias)": lambda m: f"neck.fusion_stage.layers.{4-int(m.group(1))}.prompt_depth_layer.convolution3.{m.group(2)}",
# Head
r"depth_head.scratch.output_conv1.(weight|bias)": r"head.conv1.\1",
r"depth_head.scratch.output_conv2.0.(weight|bias)": r"head.conv2.\1",
r"depth_head.scratch.output_conv2.2.(weight|bias)": r"head.conv3.\1",
}
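# Example (illustrative): "pretrained.blocks.3.attn.qkv.weight" matches the qkv pattern above and becomes
# "qkv_transform_weight_3"; transform_qkv_weights then splits that fused tensor into the separate
# query/key/value weights of backbone.encoder.layer.3.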
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
"""
Convert old state dict keys to new keys using regex patterns.
"""
output_dict = {}
if state_dict_keys is not None:
for old_key in state_dict_keys:
new_key = old_key
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
match = re.match(pattern, old_key)
if match:
if callable(replacement):
new_key = replacement(match)
else:
new_key = re.sub(pattern, replacement, old_key)
break
output_dict[old_key] = new_key
return output_dict
@torch.no_grad()
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits):
"""
Copy/paste/tweak model's weights to our DPT structure.
"""
# define DPT configuration
config = get_dpt_config(model_name)
model_name_to_repo = {
"prompt-depth-anything-vits": "depth-anything/prompt-depth-anything-vits",
"prompt-depth-anything-vits-transparent": "depth-anything/prompt-depth-anything-vits-transparent",
"prompt-depth-anything-vitl": "depth-anything/prompt-depth-anything-vitl",
}
# load original state_dict
repo_id = model_name_to_repo[model_name]
filename = name_to_checkpoint[model_name]
filepath = hf_hub_download(
repo_id=repo_id,
filename=f"{filename}",
)
state_dict = torch.load(filepath, map_location="cpu")["state_dict"]
state_dict = {key[9:]: state_dict[key] for key in state_dict}
# Convert state dict using mappings
key_mapping = convert_old_keys_to_new_keys(state_dict.keys())
new_state_dict = {}
for key, value in state_dict.items():
new_key = key_mapping[key]
transformed_value = transform_qkv_weights(new_key, value, config)
if isinstance(transformed_value, dict):
new_state_dict.update(transformed_value)
else:
new_state_dict[new_key] = transformed_value
# load HuggingFace model
model = PromptDepthAnythingForDepthEstimation(config)
model.load_state_dict(new_state_dict, strict=False)
model.eval()
processor = PromptDepthAnythingImageProcessor(
do_resize=True,
size=756,
ensure_multiple_of=14,
keep_aspect_ratio=True,
do_rescale=True,
do_normalize=True,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
)
url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt_depth_url = (
"https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
)
prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
inputs = processor(image, return_tensors="pt", prompt_depth=prompt_depth)
# Verify forward pass
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
print("Shape of predicted depth:", predicted_depth.shape)
print("First values:", predicted_depth[0, :3, :3])
# assert logits
if verify_logits:
expected_shape = torch.Size([1, 756, 1008])
if model_name == "prompt-depth-anything-vits":
expected_slice = torch.tensor(
[[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]]
)
elif model_name == "prompt-depth-anything-vits-transparent":
expected_slice = torch.tensor(
[[3.0058, 3.0397, 3.0460], [3.0314, 3.0393, 3.0504], [3.0326, 3.0465, 3.0545]]
)
elif model_name == "prompt-depth-anything-vitl":
expected_slice = torch.tensor(
[[3.1336, 3.1358, 3.1363], [3.1368, 3.1267, 3.1414], [3.1397, 3.1385, 3.1448]]
)
else:
raise ValueError("Not supported")
assert predicted_depth.shape == torch.Size(expected_shape)
assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=5e-3) # 5mm tolerance
print("Looks ok!")
if pytorch_dump_folder_path is not None:
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model and processor to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and processor to hub...")
model.push_to_hub(repo_id=f"{model_name.title()}-hf")
processor.push_to_hub(repo_id=f"{model_name.title()}-hf")
name_to_checkpoint = {
"prompt-depth-anything-vits": "model.ckpt",
"prompt-depth-anything-vits-transparent": "model.ckpt",
"prompt-depth-anything-vitl": "model.ckpt",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="prompt-depth-anything-vits",
type=str,
choices=name_to_checkpoint.keys(),
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
action="store_true",
help="Whether to push the model to the hub after conversion.",
)
parser.add_argument(
"--verify_logits",
action="store_false",
required=False,
help="Whether to verify the logits after conversion (enabled by default; passing this flag disables verification).",
)
args = parser.parse_args()
convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits)


@@ -0,0 +1,504 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for PromptDepthAnything."""
import math
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
if TYPE_CHECKING:
from ...modeling_outputs import DepthEstimatorOutput
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import pad, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
is_torch_available,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
logging,
requires_backends,
)
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
def _constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
x = round(val / multiple) * multiple
if max_val is not None and x > max_val:
x = math.floor(val / multiple) * multiple
if x < min_val:
x = math.ceil(val / multiple) * multiple
return x
def _get_resize_output_image_size(
input_image: np.ndarray,
output_size: Union[int, Iterable[int]],
keep_aspect_ratio: bool,
multiple: int,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
input_height, input_width = get_image_size(input_image, input_data_format)
output_height, output_width = output_size
# determine new height and width
scale_height = output_height / input_height
scale_width = output_width / input_width
if keep_aspect_ratio:
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
new_height = _constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
new_width = _constrain_to_multiple_of(scale_width * input_width, multiple=multiple)
return (new_height, new_width)
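# Example (illustrative): a 1440x1920 image with output_size=(756, 756), keep_aspect_ratio=True and multiple=14
# is resized to (756, 1008): the scale closer to 1 (fitting the height here) is applied to both sides,
# and each side is snapped to a multiple of 14.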
class PromptDepthAnythingImageProcessor(BaseImageProcessor):
r"""
Constructs a PromptDepthAnything image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
size (`Dict[str, int]`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the image after resizing. Can be overridden by `size` in `preprocess`.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
be overridden by `keep_aspect_ratio` in `preprocess`.
ensure_multiple_of (`int`, *optional*, defaults to 1):
If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overridden
by `ensure_multiple_of` in `preprocess`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
`preprocess`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `False`):
Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
combination with DPT.
size_divisor (`int`, *optional*):
If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
DINOv2 paper, which uses the model in combination with DPT.
prompt_scale_to_meter (`float`, *optional*, defaults to 0.001):
Scale factor to convert the prompt depth to meters.
"""
model_input_names = ["pixel_values", "prompt_depth"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = False,
size_divisor: int = None,
prompt_scale_to_meter: float = 0.001, # default unit is mm
**kwargs,
):
super().__init__(**kwargs)
size = size if size is not None else {"height": 384, "width": 384}
size = get_size_dict(size)
self.do_resize = do_resize
self.size = size
self.keep_aspect_ratio = keep_aspect_ratio
self.ensure_multiple_of = ensure_multiple_of
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad
self.size_divisor = size_divisor
self.prompt_scale_to_meter = prompt_scale_to_meter
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
keep_aspect_ratio: bool = False,
ensure_multiple_of: int = 1,
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
set, the image is resized to a size that is a multiple of this value.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Target size of the output image.
keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
ensure_multiple_of (`int`, *optional*, defaults to 1):
The image is resized to a size that is a multiple of this value.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
output_size = _get_resize_output_image_size(
image,
output_size=(size["height"], size["width"]),
keep_aspect_ratio=keep_aspect_ratio,
multiple=ensure_multiple_of,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def pad_image(
self,
image: np.ndarray,
size_divisor: int,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Center pad an image so that its height and width are multiples of `size_divisor`.
Args:
image (`np.ndarray`):
Image to pad.
size_divisor (`int`):
The width and height of the image will be padded to a multiple of this number.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
def _get_pad(size, size_divisor):
new_size = math.ceil(size / size_divisor) * size_divisor
pad_size = new_size - size
pad_size_left = pad_size // 2
pad_size_right = pad_size - pad_size_left
return pad_size_left, pad_size_right
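# Example (illustrative): size=518 with size_divisor=32 gives new_size=544 and a (13, 13) pad.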
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, input_data_format)
pad_size_top, pad_size_bottom = _get_pad(height, size_divisor)
pad_size_left, pad_size_right = _get_pad(width, size_divisor)
padded_image = pad(
image, ((pad_size_top, pad_size_bottom), (pad_size_left, pad_size_right)), data_format=data_format
)
return padded_image
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
prompt_depth: Optional[ImageInput] = None,
do_resize: Optional[bool] = None,
size: Optional[int] = None,
keep_aspect_ratio: Optional[bool] = None,
ensure_multiple_of: Optional[int] = None,
resample: Optional[PILImageResampling] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
size_divisor: Optional[int] = None,
prompt_scale_to_meter: Optional[float] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
prompt_depth (`ImageInput`, *optional*):
Prompt depth to preprocess: sparse depth obtained from multi-view geometry or low-resolution depth
from a depth sensor. It generally has shape (height, width), where height and width can be smaller than
those of the images. If it is `None`, no prompt depth is used and the output depth will be a monocular
relative depth. If the prompt depth is not in meters, pass `prompt_scale_to_meter`, the scale factor
that converts it to meters.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
resized to a size that is a multiple of this value.
keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If
True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
Ensure that the image size is a multiple of this value.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation.
prompt_scale_to_meter (`float`, *optional*, defaults to `self.prompt_scale_to_meter`):
Scale factor to convert the prompt depth to meters.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size)
keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisor,
do_resize=do_resize,
size=size,
resample=resample,
)
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
preprocessed_images = []
for image in images:
if do_resize:
image = self.resize(
image=image,
size=size,
resample=resample,
keep_aspect_ratio=keep_aspect_ratio,
ensure_multiple_of=ensure_multiple_of,
input_data_format=input_data_format,
)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
if do_pad:
image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
preprocessed_images.append(image)
images = preprocessed_images
data = {"pixel_values": images}
if prompt_depth is not None:
# prompt_depth is a list of images with shape (height, width)
# we need to convert it to a list of images with shape (1, height, width)
prompt_depths = make_list_of_images(prompt_depth, expected_ndims=2)
# Validate prompt_depths has same length as images
if len(prompt_depths) != len(images):
raise ValueError(
f"Number of prompt depth images ({len(prompt_depths)}) does not match number of input images ({len(images)})"
)
if prompt_scale_to_meter is None:
prompt_scale_to_meter = self.prompt_scale_to_meter
processed_prompt_depths = []
for depth in prompt_depths:
depth = to_numpy_array(depth)
depth = depth * prompt_scale_to_meter
if depth.min() == depth.max():
# Prompt depth is invalid, min and max are the same.
# We can simply select one pixel and set it to a small value.
depth[0, 0] = depth[0, 0] + 1e-6
depth = depth[..., None].astype(np.float32)
depth = to_channel_dimension_format(depth, data_format, input_channel_dim=input_data_format)
processed_prompt_depths.append(depth)
prompt_depths = processed_prompt_depths
data["prompt_depth"] = prompt_depths
return BatchFeature(data=data, tensor_type=return_tensors)
# Copied from transformers.models.dpt.image_processing_dpt.DPTImageProcessor.post_process_depth_estimation with DPT->PromptDepthAnything
def post_process_depth_estimation(
self,
outputs: "DepthEstimatorOutput",
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
) -> List[Dict[str, TensorType]]:
"""
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
Only supports PyTorch.
Args:
outputs ([`DepthEstimatorOutput`]):
Raw outputs of the model.
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
Returns:
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
predictions.
"""
requires_backends(self, "torch")
predicted_depth = outputs.predicted_depth
if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
)
results = []
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
for depth, target_size in zip(predicted_depth, target_sizes):
if target_size is not None:
depth = torch.nn.functional.interpolate(
depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
).squeeze()
results.append({"predicted_depth": depth})
return results
__all__ = ["PromptDepthAnythingImageProcessor"]


@@ -0,0 +1,546 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_prompt_depth_anything.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.utils.generic import torch_int
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel
from ...utils.backbone_utils import load_backbone
from .configuration_prompt_depth_anything import PromptDepthAnythingConfig
_CONFIG_FOR_DOC = "PromptDepthAnythingConfig"
class PromptDepthAnythingLayer(nn.Module):
def __init__(self, config: PromptDepthAnythingConfig):
super().__init__()
self.convolution1 = nn.Conv2d(
1,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.activation1 = nn.ReLU()
self.convolution2 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.activation2 = nn.ReLU()
self.convolution3 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor:
hidden_state = self.convolution1(prompt_depth)
hidden_state = self.activation1(hidden_state)
hidden_state = self.convolution2(hidden_state)
hidden_state = self.activation2(hidden_state)
hidden_state = self.convolution3(hidden_state)
return hidden_state
class PromptDepthAnythingPreActResidualLayer(nn.Module):
"""
ResidualConvUnit, pre-activate residual unit.
Args:
config (`[PromptDepthAnythingConfig]`):
Model configuration class defining the model architecture.
"""
def __init__(self, config):
super().__init__()
self.activation1 = nn.ReLU()
self.convolution1 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.activation2 = nn.ReLU()
self.convolution2 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
residual = hidden_state
hidden_state = self.activation1(hidden_state)
hidden_state = self.convolution1(hidden_state)
hidden_state = self.activation2(hidden_state)
hidden_state = self.convolution2(hidden_state)
return hidden_state + residual
class PromptDepthAnythingFeatureFusionLayer(nn.Module):
"""Feature fusion layer, merges feature maps from different stages.
Args:
config (`[PromptDepthAnythingConfig]`):
Model configuration class defining the model architecture.
"""
def __init__(self, config: PromptDepthAnythingConfig):
super().__init__()
self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
self.residual_layer1 = PromptDepthAnythingPreActResidualLayer(config)
self.residual_layer2 = PromptDepthAnythingPreActResidualLayer(config)
self.prompt_depth_layer = PromptDepthAnythingLayer(config)
def forward(self, hidden_state, residual=None, size=None, prompt_depth=None):
if residual is not None:
if hidden_state.shape != residual.shape:
residual = nn.functional.interpolate(
residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
)
hidden_state = hidden_state + self.residual_layer1(residual)
hidden_state = self.residual_layer2(hidden_state)
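# Prompt fusion: the LiDAR prompt depth is resized to the current feature resolution and added to the fused
# features below, so the prompt is injected at every fusion scale of the decoder (the multi-scale prompt
# fusion described in the paper).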
if prompt_depth is not None:
prompt_depth = nn.functional.interpolate(
prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
)
res = self.prompt_depth_layer(prompt_depth)
hidden_state = hidden_state + res
modifier = {"scale_factor": 2} if size is None else {"size": size}
hidden_state = nn.functional.interpolate(
hidden_state,
**modifier,
mode="bilinear",
align_corners=True,
)
hidden_state = self.projection(hidden_state)
return hidden_state
class PromptDepthAnythingFeatureFusionStage(nn.Module):
def __init__(self, config):
super().__init__()
self.layers = nn.ModuleList()
for _ in range(len(config.neck_hidden_sizes)):
self.layers.append(PromptDepthAnythingFeatureFusionLayer(config))
def forward(self, hidden_states, size=None, prompt_depth=None):
# reversing the hidden_states, we start from the last
hidden_states = hidden_states[::-1]
fused_hidden_states = []
fused_hidden_state = None
for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
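            # upsample each fused map to the spatial size of the next (higher-resolution) stage;
            # the last stage has no successor, so the fusion layer falls back to a fixed 2x upsampling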
size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
if fused_hidden_state is None:
# first layer only uses the last hidden_state
fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth)
else:
fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth)
fused_hidden_states.append(fused_hidden_state)
return fused_hidden_states
class PromptDepthAnythingDepthEstimationHead(nn.Module):
"""
Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
the predictions to the input resolution after the first convolutional layer (details can be found in the DPT paper's
supplementary material). The final activation function is either ReLU or Sigmoid, depending on the depth estimation
type (relative or metric). For metric depth estimation, the output is scaled by the maximum depth used during pretraining.
"""
def __init__(self, config):
super().__init__()
self.head_in_index = config.head_in_index
self.patch_size = config.patch_size
features = config.fusion_hidden_size
self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(features // 2, config.head_hidden_size, kernel_size=3, stride=1, padding=1)
self.activation1 = nn.ReLU()
self.conv3 = nn.Conv2d(config.head_hidden_size, 1, kernel_size=1, stride=1, padding=0)
if config.depth_estimation_type == "relative":
self.activation2 = nn.ReLU()
elif config.depth_estimation_type == "metric":
self.activation2 = nn.Sigmoid()
else:
raise ValueError(f"Unknown depth estimation type: {config.depth_estimation_type}")
self.max_depth = config.max_depth
def forward(self, hidden_states: List[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
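        # only the last fused feature map (the highest-resolution one) feeds the head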
hidden_states = hidden_states[-1]
predicted_depth = self.conv1(hidden_states)
target_height = torch_int(patch_height * self.patch_size)
target_width = torch_int(patch_width * self.patch_size)
predicted_depth = nn.functional.interpolate(
predicted_depth,
(target_height, target_width),
mode="bilinear",
align_corners=True,
)
predicted_depth = self.conv2(predicted_depth)
predicted_depth = self.activation1(predicted_depth)
predicted_depth = self.conv3(predicted_depth)
predicted_depth = self.activation2(predicted_depth)
# (batch_size, 1, height, width) -> (batch_size, height, width), which
# keeps the same behavior as Depth Anything v1 & v2
predicted_depth = predicted_depth.squeeze(dim=1)
return predicted_depth
class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = PromptDepthAnythingConfig
base_model_prefix = "prompt_depth_anything"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class PromptDepthAnythingReassembleLayer(nn.Module):
def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int):
super().__init__()
self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)
# up/down sampling depending on factor
if factor > 1:
self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
elif factor == 1:
self.resize = nn.Identity()
elif factor < 1:
# so should downsample
stride = torch_int(1 / factor)
self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1)
def forward(self, hidden_state):
hidden_state = self.projection(hidden_state)
hidden_state = self.resize(hidden_state)
return hidden_state
class PromptDepthAnythingReassembleStage(nn.Module):
"""
This class reassembles the hidden states of the backbone into image-like feature representations at various
resolutions.
This happens in 3 stages:
1. Take the patch embeddings and reshape them to image-like feature representations.
2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
    3. Resize the spatial dimensions (height, width).
Args:
config (`[PromptDepthAnythingConfig]`):
Model configuration class defining the model architecture.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList()
for channels, factor in zip(config.neck_hidden_sizes, config.reassemble_factors):
self.layers.append(PromptDepthAnythingReassembleLayer(config, channels=channels, factor=factor))
def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
"""
Args:
hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
List of hidden states from the backbone.
"""
out = []
for i, hidden_state in enumerate(hidden_states):
# reshape to (batch_size, num_channels, height, width)
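            # drop the CLS token so that only patch tokens remain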
hidden_state = hidden_state[:, 1:]
batch_size, _, num_channels = hidden_state.shape
hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
hidden_state = self.layers[i](hidden_state)
out.append(hidden_state)
return out
class PromptDepthAnythingNeck(nn.Module):
"""
PromptDepthAnythingNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
input and produces another list of tensors as output. For PromptDepthAnything, it includes 2 stages:
* PromptDepthAnythingReassembleStage
* PromptDepthAnythingFeatureFusionStage.
    Args:
        config (`[PromptDepthAnythingConfig]`):
            Model configuration class defining the model architecture.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.reassemble_stage = PromptDepthAnythingReassembleStage(config)
self.convs = nn.ModuleList()
for channel in config.neck_hidden_sizes:
self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
# fusion
self.fusion_stage = PromptDepthAnythingFeatureFusionStage(config)
def forward(
self,
hidden_states: List[torch.Tensor],
patch_height: Optional[int] = None,
patch_width: Optional[int] = None,
prompt_depth: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""
Args:
hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
List of hidden states from the backbone.
"""
if not isinstance(hidden_states, (tuple, list)):
raise TypeError("hidden_states should be a tuple or list of tensors")
if len(hidden_states) != len(self.config.neck_hidden_sizes):
raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
# postprocess hidden states
hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
# fusion blocks
output = self.fusion_stage(features, prompt_depth=prompt_depth)
return output
PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
        prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
            Prompt depth is a sparse or low-resolution depth map obtained from multi-view geometry or a
            low-resolution depth sensor. Its height and width may be smaller than those of the input images.
            If it is `None`, no prompt depth is used and the model predicts monocular relative depth.
            Values are recommended to be in meters, but this is not required.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"""
Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
""",
PROMPT_DEPTH_ANYTHING_START_DOCSTRING,
)
class PromptDepthAnythingForDepthEstimation(PromptDepthAnythingPreTrainedModel):
_no_split_modules = ["DPTViTEmbeddings"]
def __init__(self, config):
super().__init__(config)
self.backbone = load_backbone(config)
self.neck = PromptDepthAnythingNeck(config)
self.head = PromptDepthAnythingDepthEstimationHead(config)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.FloatTensor,
prompt_depth: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth depth estimation maps for computing the loss.
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
>>> import torch
>>> import numpy as np
>>> from PIL import Image
>>> import requests
>>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
>>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
>>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # interpolate to original size
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 1000.
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint16")) # mm
```"""
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented yet")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
outputs = self.backbone.forward_with_filtered_kwargs(
pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
)
hidden_states = outputs.feature_maps
_, _, height, width = pixel_values.shape
patch_size = self.config.patch_size
patch_height = height // patch_size
patch_width = width // patch_size
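        # min-max normalize the prompt depth per sample so the fusion layers see values in [0, 1];
        # the prediction is mapped back to the original metric range after the head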
if prompt_depth is not None:
# normalize prompt depth
batch_size = prompt_depth.shape[0]
depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values
depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values
depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1)
prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min)
# normalize done
hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth)
predicted_depth = self.head(hidden_states, patch_height, patch_width)
if prompt_depth is not None:
# denormalize predicted depth
depth_min = depth_min.squeeze(1).to(predicted_depth.device)
depth_max = depth_max.squeeze(1).to(predicted_depth.device)
predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min
# denormalize done
if not return_dict:
if output_hidden_states:
output = (predicted_depth,) + outputs[1:]
else:
output = (predicted_depth,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return DepthEstimatorOutput(
loss=loss,
predicted_depth=predicted_depth,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
__all__ = ["PromptDepthAnythingForDepthEstimation", "PromptDepthAnythingPreTrainedModel"]


@@ -0,0 +1,391 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.models.depth_anything.configuration_depth_anything import DepthAnythingConfig
from transformers.models.depth_anything.modeling_depth_anything import (
DepthAnythingDepthEstimationHead,
DepthAnythingFeatureFusionLayer,
DepthAnythingFeatureFusionStage,
DepthAnythingForDepthEstimation,
DepthAnythingNeck,
DepthAnythingReassembleStage,
)
from transformers.utils.generic import torch_int
from ...file_utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel
_CONFIG_FOR_DOC = "PromptDepthAnythingConfig"
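# The configuration is inherited unchanged from Depth Anything; only the model_type is overridden.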
class PromptDepthAnythingConfig(DepthAnythingConfig):
model_type = "prompt_depth_anything"
class PromptDepthAnythingLayer(nn.Module):
def __init__(self, config: PromptDepthAnythingConfig):
super().__init__()
self.convolution1 = nn.Conv2d(
1,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.activation1 = nn.ReLU()
self.convolution2 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
self.activation2 = nn.ReLU()
self.convolution3 = nn.Conv2d(
config.fusion_hidden_size,
config.fusion_hidden_size,
kernel_size=3,
stride=1,
padding=1,
bias=True,
)
def forward(self, prompt_depth: torch.Tensor) -> torch.Tensor:
hidden_state = self.convolution1(prompt_depth)
hidden_state = self.activation1(hidden_state)
hidden_state = self.convolution2(hidden_state)
hidden_state = self.activation2(hidden_state)
hidden_state = self.convolution3(hidden_state)
return hidden_state
class PromptDepthAnythingFeatureFusionLayer(DepthAnythingFeatureFusionLayer):
def __init__(self, config: PromptDepthAnythingConfig):
super().__init__(config)
self.prompt_depth_layer = PromptDepthAnythingLayer(config)
def forward(self, hidden_state, residual=None, size=None, prompt_depth=None):
if residual is not None:
if hidden_state.shape != residual.shape:
residual = nn.functional.interpolate(
residual, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
)
hidden_state = hidden_state + self.residual_layer1(residual)
hidden_state = self.residual_layer2(hidden_state)
if prompt_depth is not None:
prompt_depth = nn.functional.interpolate(
prompt_depth, size=hidden_state.shape[2:], mode="bilinear", align_corners=False
)
res = self.prompt_depth_layer(prompt_depth)
hidden_state = hidden_state + res
modifier = {"scale_factor": 2} if size is None else {"size": size}
hidden_state = nn.functional.interpolate(
hidden_state,
**modifier,
mode="bilinear",
align_corners=True,
)
hidden_state = self.projection(hidden_state)
return hidden_state
class PromptDepthAnythingFeatureFusionStage(DepthAnythingFeatureFusionStage):
def forward(self, hidden_states, size=None, prompt_depth=None):
# reversing the hidden_states, we start from the last
hidden_states = hidden_states[::-1]
fused_hidden_states = []
fused_hidden_state = None
for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
if fused_hidden_state is None:
# first layer only uses the last hidden_state
fused_hidden_state = layer(hidden_state, size=size, prompt_depth=prompt_depth)
else:
fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size, prompt_depth=prompt_depth)
fused_hidden_states.append(fused_hidden_state)
return fused_hidden_states
class PromptDepthAnythingDepthEstimationHead(DepthAnythingDepthEstimationHead):
def forward(self, hidden_states: List[torch.Tensor], patch_height: int, patch_width: int) -> torch.Tensor:
hidden_states = hidden_states[-1]
predicted_depth = self.conv1(hidden_states)
target_height = torch_int(patch_height * self.patch_size)
target_width = torch_int(patch_width * self.patch_size)
predicted_depth = nn.functional.interpolate(
predicted_depth,
(target_height, target_width),
mode="bilinear",
align_corners=True,
)
predicted_depth = self.conv2(predicted_depth)
predicted_depth = self.activation1(predicted_depth)
predicted_depth = self.conv3(predicted_depth)
predicted_depth = self.activation2(predicted_depth)
# (batch_size, 1, height, width) -> (batch_size, height, width), which
# keeps the same behavior as Depth Anything v1 & v2
predicted_depth = predicted_depth.squeeze(dim=1)
return predicted_depth
PROMPT_DEPTH_ANYTHING_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`PromptDepthAnythingConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
        prompt_depth (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
            Prompt depth is a sparse or low-resolution depth map obtained from multi-view geometry or a
            low-resolution depth sensor. Its height and width may be smaller than those of the input images.
            If it is `None`, no prompt depth is used and the model predicts monocular relative depth.
            Values are recommended to be in meters, but this is not required.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
class PromptDepthAnythingPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = PromptDepthAnythingConfig
base_model_prefix = "prompt_depth_anything"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class PromptDepthAnythingReassembleLayer(nn.Module):
def __init__(self, config: PromptDepthAnythingConfig, channels: int, factor: int):
super().__init__()
self.projection = nn.Conv2d(in_channels=config.reassemble_hidden_size, out_channels=channels, kernel_size=1)
# up/down sampling depending on factor
if factor > 1:
self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
elif factor == 1:
self.resize = nn.Identity()
elif factor < 1:
# so should downsample
stride = torch_int(1 / factor)
self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1)
def forward(self, hidden_state):
hidden_state = self.projection(hidden_state)
hidden_state = self.resize(hidden_state)
return hidden_state
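# The reassemble stage is reused unchanged from Depth Anything through the modular mechanism.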
class PromptDepthAnythingReassembleStage(DepthAnythingReassembleStage):
pass
class PromptDepthAnythingNeck(DepthAnythingNeck):
def forward(
self,
hidden_states: List[torch.Tensor],
patch_height: Optional[int] = None,
patch_width: Optional[int] = None,
prompt_depth: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""
Args:
hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
List of hidden states from the backbone.
"""
if not isinstance(hidden_states, (tuple, list)):
raise TypeError("hidden_states should be a tuple or list of tensors")
if len(hidden_states) != len(self.config.neck_hidden_sizes):
raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
# postprocess hidden states
hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
# fusion blocks
output = self.fusion_stage(features, prompt_depth=prompt_depth)
return output
@add_start_docstrings(
"""
Prompt Depth Anything Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
""",
PROMPT_DEPTH_ANYTHING_START_DOCSTRING,
)
class PromptDepthAnythingForDepthEstimation(DepthAnythingForDepthEstimation):
@add_start_docstrings_to_model_forward(PROMPT_DEPTH_ANYTHING_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.FloatTensor,
prompt_depth: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
r"""
```python
>>> from transformers import AutoImageProcessor, AutoModelForDepthEstimation
>>> import torch
>>> import numpy as np
>>> from PIL import Image
>>> import requests
>>> url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
>>> model = AutoModelForDepthEstimation.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
>>> prompt_depth_url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
>>> prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth)
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # interpolate to original size
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 1000.
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint16")) # mm
```"""
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented yet")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
outputs = self.backbone.forward_with_filtered_kwargs(
pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
)
hidden_states = outputs.feature_maps
_, _, height, width = pixel_values.shape
patch_size = self.config.patch_size
patch_height = height // patch_size
patch_width = width // patch_size
if prompt_depth is not None:
# normalize prompt depth
batch_size = prompt_depth.shape[0]
depth_min = torch.min(prompt_depth.reshape(batch_size, -1), dim=1).values
depth_max = torch.max(prompt_depth.reshape(batch_size, -1), dim=1).values
depth_min, depth_max = depth_min.view(batch_size, 1, 1, 1), depth_max.view(batch_size, 1, 1, 1)
prompt_depth = (prompt_depth - depth_min) / (depth_max - depth_min)
# normalize done
hidden_states = self.neck(hidden_states, patch_height, patch_width, prompt_depth=prompt_depth)
predicted_depth = self.head(hidden_states, patch_height, patch_width)
if prompt_depth is not None:
# denormalize predicted depth
depth_min = depth_min.squeeze(1).to(predicted_depth.device)
depth_max = depth_max.squeeze(1).to(predicted_depth.device)
predicted_depth = predicted_depth * (depth_max - depth_min) + depth_min
# denormalize done
if not return_dict:
if output_hidden_states:
output = (predicted_depth,) + outputs[1:]
else:
output = (predicted_depth,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return DepthEstimatorOutput(
loss=loss,
predicted_depth=predicted_depth,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
__all__ = [
"PromptDepthAnythingConfig",
"PromptDepthAnythingForDepthEstimation",
"PromptDepthAnythingPreTrainedModel",
]


@@ -7878,6 +7878,20 @@ class Pop2PianoPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class PromptDepthAnythingForDepthEstimation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PromptDepthAnythingPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ProphetNetDecoder(metaclass=DummyObject):
_backends = ["torch"]


@@ -590,6 +590,13 @@ class PoolFormerImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class PromptDepthAnythingImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class PvtImageProcessor(metaclass=DummyObject):
_backends = ["vision"]


@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.file_utils import is_vision_available
from transformers.testing_utils import require_torch, require_vision
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_vision_available():
from transformers import PromptDepthAnythingImageProcessor
class PromptDepthAnythingImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
super().__init__()
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def prepare_image_processor_dict(self):
return {
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_normalize": self.do_normalize,
"do_resize": self.do_resize,
"size": self.size,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class PromptDepthAnythingImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = PromptDepthAnythingImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = PromptDepthAnythingImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
self.assertTrue(hasattr(image_processing, "prompt_scale_to_meter"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
def test_keep_aspect_ratio(self):
size = {"height": 512, "width": 512}
image_processor = PromptDepthAnythingImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
image = np.zeros((489, 640, 3))
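        # with keep_aspect_ratio=True and ensure_multiple_of=32, a 489x640 image is resized to 512x672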
pixel_values = image_processor(image, return_tensors="pt").pixel_values
self.assertEqual(list(pixel_values.shape), [1, 3, 512, 672])
def test_prompt_depth_processing(self):
size = {"height": 756, "width": 756}
image_processor = PromptDepthAnythingImageProcessor(size=size, keep_aspect_ratio=True, ensure_multiple_of=32)
image = np.zeros((756, 1008, 3))
prompt_depth = np.random.random((192, 256))
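        # the prompt depth keeps its original 192x256 resolution; only the image is resized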
outputs = image_processor(image, prompt_depth=prompt_depth, return_tensors="pt")
pixel_values = outputs.pixel_values
prompt_depth_values = outputs.prompt_depth
self.assertEqual(list(pixel_values.shape), [1, 3, 768, 1024])
self.assertEqual(list(prompt_depth_values.shape), [1, 1, 192, 256])


@@ -0,0 +1,325 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Prompt Depth Anything model."""
import unittest
import requests
from transformers import Dinov2Config, PromptDepthAnythingConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import PromptDepthAnythingForDepthEstimation
if is_vision_available():
from PIL import Image
from transformers import AutoImageProcessor
class PromptDepthAnythingModelTester:
def __init__(
self,
parent,
batch_size=2,
num_channels=3,
image_size=32,
patch_size=16,
use_labels=True,
num_labels=3,
is_training=True,
hidden_size=4,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=8,
out_features=["stage1", "stage2"],
apply_layernorm=False,
reshape_hidden_states=False,
neck_hidden_sizes=[2, 2],
fusion_hidden_size=6,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.patch_size = patch_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.out_features = out_features
self.apply_layernorm = apply_layernorm
self.reshape_hidden_states = reshape_hidden_states
self.use_labels = use_labels
self.num_labels = num_labels
self.is_training = is_training
self.neck_hidden_sizes = neck_hidden_sizes
self.fusion_hidden_size = fusion_hidden_size
self.seq_length = (self.image_size // self.patch_size) ** 2 + 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
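        # low-resolution prompt depth at a quarter of the image resolution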
prompt_depth = floats_tensor([self.batch_size, 1, self.image_size // 4, self.image_size // 4])
config = self.get_config()
return config, pixel_values, labels, prompt_depth
def get_config(self):
return PromptDepthAnythingConfig(
backbone_config=self.get_backbone_config(),
reassemble_hidden_size=self.hidden_size,
patch_size=self.patch_size,
neck_hidden_sizes=self.neck_hidden_sizes,
fusion_hidden_size=self.fusion_hidden_size,
)
def get_backbone_config(self):
return Dinov2Config(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
is_training=self.is_training,
out_features=self.out_features,
reshape_hidden_states=self.reshape_hidden_states,
)
def create_and_check_for_depth_estimation(self, config, pixel_values, labels, prompt_depth):
config.num_labels = self.num_labels
model = PromptDepthAnythingForDepthEstimation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, prompt_depth=prompt_depth)
self.parent.assertEqual(result.predicted_depth.shape, (self.batch_size, self.image_size, self.image_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels, prompt_depth = config_and_inputs
inputs_dict = {"pixel_values": pixel_values, "prompt_depth": prompt_depth}
return config, inputs_dict
@require_torch
class PromptDepthAnythingModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as Prompt Depth Anything does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (PromptDepthAnythingForDepthEstimation,) if is_torch_available() else ()
pipeline_model_mapping = (
{"depth-estimation": PromptDepthAnythingForDepthEstimation} if is_torch_available() else {}
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = PromptDepthAnythingModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=PromptDepthAnythingConfig,
has_text_modality=False,
hidden_size=37,
common_properties=["patch_size"],
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(
reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings"
)
def test_inputs_embeds(self):
pass
def test_for_depth_estimation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_depth_estimation(*config_and_inputs)
@unittest.skip(reason="Prompt Depth Anything does not support training yet")
def test_training(self):
pass
@unittest.skip(reason="Prompt Depth Anything does not support training yet")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="Prompt Depth Anything with AutoBackbone does not have a base model and hence no input_embeddings"
)
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="Prompt Depth Anything with AutoBackbone does not have a base model")
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@slow
def test_model_from_pretrained(self):
model_name = "depth-anything/prompt-depth-anything-vits-hf"
model = PromptDepthAnythingForDepthEstimation.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_backbone_selection(self):
def _validate_backbone_init():
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
self.assertEqual(len(model.backbone.out_indices), 2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
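        # switch from an explicit backbone_config to loading a pretrained backbone by name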
config.backbone = "facebook/dinov2-small"
config.use_pretrained_backbone = True
config.use_timm_backbone = False
config.backbone_config = None
config.backbone_kwargs = {"out_indices": [-2, -1]}
_validate_backbone_init()
def prepare_img():
url = "https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/image.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
return image
def prepare_prompt_depth():
prompt_depth_url = (
"https://github.com/DepthAnything/PromptDA/blob/main/assets/example_images/arkit_depth.png?raw=true"
)
prompt_depth = Image.open(requests.get(prompt_depth_url, stream=True).raw)
return prompt_depth
@require_torch
@require_vision
@slow
class PromptDepthAnythingModelIntegrationTest(unittest.TestCase):
def test_inference_wo_prompt_depth(self):
image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
model = PromptDepthAnythingForDepthEstimation.from_pretrained(
"depth-anything/prompt-depth-anything-vits-hf"
).to(torch_device)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
expected_shape = torch.Size([1, 756, 1008])
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[0.5029, 0.5120, 0.5176], [0.4998, 0.5147, 0.5197], [0.4973, 0.5201, 0.5241]]
).to(torch_device)
self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3))
def test_inference(self):
image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
model = PromptDepthAnythingForDepthEstimation.from_pretrained(
"depth-anything/prompt-depth-anything-vits-hf"
).to(torch_device)
image = prepare_img()
prompt_depth = prepare_prompt_depth()
inputs = image_processor(images=image, return_tensors="pt", prompt_depth=prompt_depth).to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
expected_shape = torch.Size([1, 756, 1008])
self.assertEqual(predicted_depth.shape, expected_shape)
expected_slice = torch.tensor(
[[3.0100, 3.0016, 3.0219], [3.0046, 3.0137, 3.0275], [3.0083, 3.0191, 3.0292]]
).to(torch_device)
self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-3))
def test_export(self):
for strict in [True, False]:
with self.subTest(strict=strict):
if not is_torch_greater_or_equal_than_2_4:
self.skipTest(reason="This test requires torch >= 2.4 to run.")
model = (
PromptDepthAnythingForDepthEstimation.from_pretrained(
"depth-anything/prompt-depth-anything-vits-hf"
)
.to(torch_device)
.eval()
)
image_processor = AutoImageProcessor.from_pretrained("depth-anything/prompt-depth-anything-vits-hf")
image = prepare_img()
prompt_depth = prepare_prompt_depth()
inputs = image_processor(images=image, prompt_depth=prompt_depth, return_tensors="pt").to(torch_device)
exported_program = torch.export.export(
model,
args=(inputs["pixel_values"], inputs["prompt_depth"]),
strict=strict,
)
with torch.no_grad():
eager_outputs = model(**inputs)
exported_outputs = exported_program.module().forward(
inputs["pixel_values"], inputs["prompt_depth"]
)
self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape)
self.assertTrue(
torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4)
)