# DepthPro
## Overview

The DepthPro model was proposed in [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://huggingface.co/papers/2410.02073) by Aleksei Bochkovskii, Amaël Delaunoy, Hugo Germain, Marcel Santos, Yichao Zhou, Stephan R. Richter, Vladlen Koltun.

DepthPro is a foundation model for zero-shot metric monocular depth estimation, designed to generate high-resolution depth maps with remarkable sharpness and fine-grained details. It employs a multi-scale Vision Transformer (ViT)-based architecture, where images are downsampled, divided into patches, and processed using a shared Dinov2 encoder. The extracted patch-level features are merged, upsampled, and refined using a DPT-like fusion stage, enabling precise depth estimation.

The abstract from the paper is the following:

*We present a foundation model for zero-shot metric monocular depth estimation. Our model, Depth Pro, synthesizes high-resolution depth maps with unparalleled sharpness and high-frequency details. The predictions are metric, with absolute scale, without relying on the availability of metadata such as camera intrinsics. And the model is fast, producing a 2.25-megapixel depth map in 0.3 seconds on a standard GPU. These characteristics are enabled by a number of technical contributions, including an efficient multi-scale vision transformer for dense prediction, a training protocol that combines real and synthetic datasets to achieve high metric accuracy alongside fine boundary tracing, dedicated evaluation metrics for boundary accuracy in estimated depth maps, and state-of-the-art focal length estimation from a single image. Extensive experiments analyze specific design choices and demonstrate that Depth Pro outperforms prior work along multiple dimensions.*

*DepthPro outputs. Taken from the official code.*

This model was contributed by [geetu040](https://github.com/geetu040). The original code can be found [here](https://github.com/apple/ml-depth-pro).
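The abstract highlights focal length estimation from a single image, and the post-processing step in the usage example returns both a `field_of_view` and a `focal_length`. Under the standard pinhole camera model, a horizontal field of view `fov` and a focal length in pixels `f` for an image of width `W` are related by `f = W / (2 * tan(fov / 2))`. The minimal sketch below illustrates only this geometric relation. It assumes the field of view is expressed in degrees, and the helper names are hypothetical, not part of the Transformers API; it does not claim to reproduce how the library computes its outputs.

```python
import math


def focal_length_from_fov(fov_degrees: float, image_width: int) -> float:
    """Pinhole model: focal length in pixels from a horizontal field of view in degrees."""
    return 0.5 * image_width / math.tan(0.5 * math.radians(fov_degrees))


def fov_from_focal_length(focal_px: float, image_width: int) -> float:
    """Inverse pinhole relation: horizontal field of view in degrees from a focal length in pixels."""
    return math.degrees(2.0 * math.atan(0.5 * image_width / focal_px))


# Hypothetical values: a 90 degree horizontal field of view on a 640-pixel-wide image.
f = focal_length_from_fov(90.0, 640)
print(round(f, 3))  # 320.0
print(round(fov_from_focal_length(f, 640), 3))  # 90.0
```

Because the relation depends on the image width, a focal length in pixels is only meaningful together with the resolution it refers to; the field of view, in contrast, does not change when the image is resized.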
## Usage Tips

The DepthPro model processes an input image by first downsampling it at multiple scales and splitting each scaled version into patches. These patches are then encoded using a shared Vision Transformer (ViT)-based Dinov2 patch encoder, while the full image is processed by a separate image encoder. The extracted patch features are merged into feature maps, upsampled, and fused using a DPT-like decoder to generate the final depth estimation. If enabled, an additional Field of View (FOV) encoder processes the image to estimate the camera's field of view, aiding in depth accuracy.

```py
>>> import requests
>>> from PIL import Image
>>> import torch
>>> from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = DepthProImageProcessorFast.from_pretrained("apple/DepthPro-hf")
>>> model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf").to(device)

>>> inputs = image_processor(images=image, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> post_processed_output = image_processor.post_process_depth_estimation(
...     outputs, target_sizes=[(image.height, image.width)],
... )

>>> field_of_view = post_processed_output[0]["field_of_view"]
>>> focal_length = post_processed_output[0]["focal_length"]
>>> depth = post_processed_output[0]["predicted_depth"]
>>> # min-max normalize the depth map to [0, 1], then scale to [0, 255] for visualization
>>> depth = (depth - depth.min()) / (depth.max() - depth.min())
>>> depth = depth * 255.
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
```

### Architecture and Configuration

*DepthPro architecture. Taken from the original paper.*
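The multi-scale patching performed on the `patch_encoder` branch, and the single resized view given to the `image_encoder`, can be sketched with plain PyTorch tensor operations. This is a minimal illustration under stated assumptions, not the Transformers implementation: the helper names are hypothetical, and the 1536x1536 input, the 384-pixel patch size, and the ratio and overlap values stand in for the `scaled_images_ratios`, `scaled_images_overlap_ratios`, and `patch_size` options described below. The actual model may resize, pad, or interpolate differently.

```python
import torch
import torch.nn.functional as F


def split_into_patches(image: torch.Tensor, patch_size: int, overlap_ratio: float) -> torch.Tensor:
    """Split a (B, C, H, W) tensor into overlapping square patches.

    Returns a tensor of shape (B * n_h * n_w, C, patch_size, patch_size).
    Assumes (H - patch_size) and (W - patch_size) are multiples of the stride;
    otherwise `unfold` silently drops the remainder at the border.
    """
    stride = int(patch_size * (1 - overlap_ratio))
    # Unfold height (dim 2), then width (dim 3): (B, C, n_h, n_w, patch_size, patch_size)
    patches = image.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    b, c, n_h, n_w, ph, pw = patches.shape
    # Move the patch grid next to the batch dimension and flatten it into the batch
    return patches.permute(0, 2, 3, 1, 4, 5).reshape(b * n_h * n_w, c, ph, pw)


def multi_scale_patches(image, ratios, overlap_ratios, patch_size):
    """Resize the image once per ratio and split each resized copy into patches."""
    outputs = []
    for ratio, overlap in zip(ratios, overlap_ratios):
        scaled = F.interpolate(image, scale_factor=ratio, mode="bilinear", align_corners=False)
        outputs.append(split_into_patches(scaled, patch_size, overlap))
    return outputs


# Hypothetical values: a 1536x1536 input, 384x384 patches, three scales.
image = torch.randn(1, 3, 1536, 1536)
patches = multi_scale_patches(image, ratios=[0.25, 0.5, 1.0], overlap_ratios=[0.0, 0.5, 0.25], patch_size=384)
# The separate image encoder instead sees the whole image resized to a single patch.
global_view = F.interpolate(image, size=(384, 384), mode="bilinear", align_corners=False)

print([p.shape[0] for p in patches])  # [1, 9, 25]
print(global_view.shape)  # torch.Size([1, 3, 384, 384])
```

With these values the three scales yield a 1x1, 3x3, and 5x5 grid of patches, all of the same size, which is what allows one shared encoder to process every scale.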
The `DepthProForDepthEstimation` model uses a `DepthProEncoder` for encoding the input image and a `FeatureFusionStage` for fusing the output features from the encoder.

The `DepthProEncoder` further uses two encoders:
- `patch_encoder`
    - The input image is scaled with multiple ratios, as specified in the `scaled_images_ratios` configuration.
    - Each scaled image is split into smaller **patches** of size `patch_size`, with overlapping areas determined by `scaled_images_overlap_ratios`.
    - These patches are processed by the **`patch_encoder`**.
- `image_encoder`
    - The input image is also rescaled to `patch_size` and processed by the **`image_encoder`**.

These encoders can be configured via `patch_model_config` and `image_model_config` respectively, each of which is a separate `Dinov2Model` by default.

Outputs from both encoders (`last_hidden_state`) and selected intermediate states (`hidden_states`) from the **`patch_encoder`** are fused by a `DPT`-based `FeatureFusionStage` for depth estimation.

### Field-of-View (FOV) Prediction

The network is supplemented with a focal length estimation head. A small convolutional head ingests frozen features from the depth estimation network and task-specific features from a separate ViT image encoder to predict the horizontal angular field of view.

The `use_fov_model` parameter in `DepthProConfig` controls whether **FOV prediction** is enabled. By default, it is set to `False` to conserve memory and computation. When enabled, the **FOV encoder** is instantiated based on the `fov_model_config` parameter, which defaults to a `Dinov2Model`. The `use_fov_model` parameter can also be passed when initializing the `DepthProForDepthEstimation` model.

The pretrained model at checkpoint `apple/DepthPro-hf` uses the FOV encoder. To use the pretrained model without the FOV encoder, set `use_fov_model=False` when loading the model, which saves computation.
```py
>>> from transformers import DepthProForDepthEstimation
>>> model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf", use_fov_model=False)
```

To instantiate a new model with the FOV encoder, set `use_fov_model=True` in the config.

```py
>>> from transformers import DepthProConfig, DepthProForDepthEstimation
>>> config = DepthProConfig(use_fov_model=True)
>>> model = DepthProForDepthEstimation(config)
```

Or set `use_fov_model=True` when initializing the model, which overrides the value in the config.

```py
>>> from transformers import DepthProConfig, DepthProForDepthEstimation
>>> config = DepthProConfig()
>>> model = DepthProForDepthEstimation(config, use_fov_model=True)
```

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the [official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```py
import torch
from transformers import DepthProForDepthEstimation

# Explicitly request SDPA and load the weights in half precision
model = DepthProForDepthEstimation.from_pretrained("apple/DepthPro-hf", attn_implementation="sdpa", torch_dtype=torch.float16)
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and the `google/vit-base-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), SDPA mode | Speedup, Eager / SDPA (x) |
|------------|-----------------------------------------|----------------------------------------|---------------------------|
| 1          | 7                                       | 6                                      | 1.17                      |
| 2          | 8                                       | 6                                      | 1.33                      |
| 4          | 8                                       | 6                                      | 1.33                      |
| 8          | 8                                       | 6                                      | 1.33                      |

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DepthPro:

- Research Paper: [Depth Pro: Sharp Monocular Metric Depth in Less Than a Second](https://huggingface.co/papers/2410.02073)
- Official Implementation: [apple/ml-depth-pro](https://github.com/apple/ml-depth-pro)
- DepthPro Inference Notebook: [DepthPro Inference](https://github.com/qubvel/transformers-notebooks/blob/main/notebooks/DepthPro_inference.ipynb)
- DepthPro for Super Resolution and Image Segmentation
    - Read blog on Medium: [Depth Pro: Beyond Depth](https://medium.com/@raoarmaghanshakir040/depth-pro-beyond-depth-9d822fc557ba)
    - Code on GitHub: [geetu040/depthpro-beyond-depth](https://github.com/geetu040/depthpro-beyond-depth)

If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

## DepthProConfig

[[autodoc]] DepthProConfig

## DepthProImageProcessor

[[autodoc]] DepthProImageProcessor
    - preprocess
    - post_process_depth_estimation

## DepthProImageProcessorFast

[[autodoc]] DepthProImageProcessorFast
    - preprocess
    - post_process_depth_estimation

## DepthProModel

[[autodoc]] DepthProModel
    - forward

## DepthProForDepthEstimation

[[autodoc]] DepthProForDepthEstimation
    - forward