This commit is contained in:
Jade Choghari 2025-07-02 18:26:11 -04:00 committed by GitHub
commit 8e5a1fc0a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 2107 additions and 1 deletions

View File

@ -733,6 +733,8 @@
title: DiT title: DiT
- local: model_doc/dpt - local: model_doc/dpt
title: DPT title: DPT
- local: model_doc/dust3r
title: Dust3R
- local: model_doc/efficientformer - local: model_doc/efficientformer
title: EfficientFormer title: EfficientFormer
- local: model_doc/efficientnet - local: model_doc/efficientnet
@ -1144,4 +1146,3 @@
title: Environment Variables title: Environment Variables
title: Reference title: Reference
title: API title: API

View File

@ -0,0 +1,89 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
# Dust3R
## Overview
The Dust3R model was proposed in [DUSt3R: Geometric 3D Vision Made Easy](https://arxiv.org/abs/2312.14132) by Shuzhe Wang, Vincent Leroy, Yohann Cabon, Boris Chidlovskii, and Jerome Revaud.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
*<INSERT PAPER ABSTRACT HERE>*
Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](https://github.com/naver/dust3r).
## Dust3RConfig
[[autodoc]] Dust3RConfig
## Dust3RImageProcessor
[[autodoc]] Dust3RImageProcessor
- preprocess
## Dust3RModel
[[autodoc]] Dust3RModel
- forward
## Dust3RForMaskedImageModeling
[[autodoc]] Dust3RForMaskedImageModeling
- forward
## Dust3RForImageClassification
[[autodoc]] Dust3RForImageClassification
- forward
## TFDust3RModel
[[autodoc]] TFDust3RModel
- call
## TFDust3RForImageClassification
[[autodoc]] TFDust3RForImageClassification
- call
## FlaxDust3RModel
[[autodoc]] FlaxDust3RModel
- __call__
## FlaxDust3RForImageClassification
[[autodoc]] FlaxDust3RForImageClassification
- __call__

View File

@ -371,6 +371,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"), ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("visual_bert", "VisualBertConfig"), ("visual_bert", "VisualBertConfig"),
("vit", "ViTConfig"), ("vit", "ViTConfig"),
("dust3r", "Dust3RConfig"),
("vit_hybrid", "ViTHybridConfig"), ("vit_hybrid", "ViTHybridConfig"),
("vit_mae", "ViTMAEConfig"), ("vit_mae", "ViTMAEConfig"),
("vit_msn", "ViTMSNConfig"), ("vit_msn", "ViTMSNConfig"),
@ -772,6 +773,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("vision-text-dual-encoder", "VisionTextDualEncoder"), ("vision-text-dual-encoder", "VisionTextDualEncoder"),
("visual_bert", "VisualBERT"), ("visual_bert", "VisualBERT"),
("vit", "ViT"), ("vit", "ViT"),
("dust3r", "Dust3R"),
("vit_hybrid", "ViT Hybrid"), ("vit_hybrid", "ViT Hybrid"),
("vit_mae", "ViTMAE"), ("vit_mae", "ViTMAE"),
("vit_msn", "ViTMSN"), ("vit_msn", "ViTMSN"),

View File

@ -169,6 +169,7 @@ else:
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")), ("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("vit", ("ViTImageProcessor", "ViTImageProcessorFast")), ("vit", ("ViTImageProcessor", "ViTImageProcessorFast")),
("dust3r", ("Dust3RImageProcessor", "Dust3RImageProcessorFast")),
("vit_hybrid", ("ViTHybridImageProcessor",)), ("vit_hybrid", ("ViTHybridImageProcessor",)),
("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")), ("vit_mae", ("ViTImageProcessor", "ViTImageProcessorFast")),
("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")), ("vit_msn", ("ViTImageProcessor", "ViTImageProcessorFast")),

View File

@ -344,6 +344,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("vision-text-dual-encoder", "VisionTextDualEncoderModel"), ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("visual_bert", "VisualBertModel"), ("visual_bert", "VisualBertModel"),
("vit", "ViTModel"), ("vit", "ViTModel"),
("dust3r", "Dust3RModel"),
("vit_hybrid", "ViTHybridModel"), ("vit_hybrid", "ViTHybridModel"),
("vit_mae", "ViTMAEModel"), ("vit_mae", "ViTMAEModel"),
("vit_msn", "ViTMSNModel"), ("vit_msn", "ViTMSNModel"),
@ -725,6 +726,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("van", "VanModel"), ("van", "VanModel"),
("videomae", "VideoMAEModel"), ("videomae", "VideoMAEModel"),
("vit", "ViTModel"), ("vit", "ViTModel"),
("dust3r", "Dust3RModel"),
("vit_hybrid", "ViTHybridModel"), ("vit_hybrid", "ViTHybridModel"),
("vit_mae", "ViTMAEModel"), ("vit_mae", "ViTMAEModel"),
("vit_msn", "ViTMSNModel"), ("vit_msn", "ViTMSNModel"),
@ -741,6 +743,7 @@ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
("swin", "SwinForMaskedImageModeling"), ("swin", "SwinForMaskedImageModeling"),
("swinv2", "Swinv2ForMaskedImageModeling"), ("swinv2", "Swinv2ForMaskedImageModeling"),
("vit", "ViTForMaskedImageModeling"), ("vit", "ViTForMaskedImageModeling"),
("dust3r", "Dust3RForMaskedImageModeling"),
] ]
) )
@ -816,6 +819,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("timm_wrapper", "TimmWrapperForImageClassification"), ("timm_wrapper", "TimmWrapperForImageClassification"),
("van", "VanForImageClassification"), ("van", "VanForImageClassification"),
("vit", "ViTForImageClassification"), ("vit", "ViTForImageClassification"),
("dust3r", "Dust3RForImageClassification"),
("vit_hybrid", "ViTHybridForImageClassification"), ("vit_hybrid", "ViTHybridForImageClassification"),
("vit_msn", "ViTMSNForImageClassification"), ("vit_msn", "ViTMSNForImageClassification"),
] ]

View File

@ -60,6 +60,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("t5", "FlaxT5Model"), ("t5", "FlaxT5Model"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"), ("vit", "FlaxViTModel"),
("dust3r", "FlaxDust3RModel"),
("wav2vec2", "FlaxWav2Vec2Model"), ("wav2vec2", "FlaxWav2Vec2Model"),
("whisper", "FlaxWhisperModel"), ("whisper", "FlaxWhisperModel"),
("xglm", "FlaxXGLMModel"), ("xglm", "FlaxXGLMModel"),
@ -129,6 +130,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("regnet", "FlaxRegNetForImageClassification"), ("regnet", "FlaxRegNetForImageClassification"),
("resnet", "FlaxResNetForImageClassification"), ("resnet", "FlaxResNetForImageClassification"),
("vit", "FlaxViTForImageClassification"), ("vit", "FlaxViTForImageClassification"),
("dust3r", "FlaxDust3RForImageClassification"),
] ]
) )

View File

@ -90,6 +90,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("transfo-xl", "TFTransfoXLModel"), ("transfo-xl", "TFTransfoXLModel"),
("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"), ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
("vit", "TFViTModel"), ("vit", "TFViTModel"),
("dust3r", "TFDust3RModel"),
("vit_mae", "TFViTMAEModel"), ("vit_mae", "TFViTMAEModel"),
("wav2vec2", "TFWav2Vec2Model"), ("wav2vec2", "TFWav2Vec2Model"),
("whisper", "TFWhisperModel"), ("whisper", "TFWhisperModel"),
@ -221,6 +222,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("swiftformer", "TFSwiftFormerForImageClassification"), ("swiftformer", "TFSwiftFormerForImageClassification"),
("swin", "TFSwinForImageClassification"), ("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"), ("vit", "TFViTForImageClassification"),
("dust3r", "TFDust3RForImageClassification"),
] ]
) )

View File

@ -0,0 +1,32 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
    # This branch is only taken by static type checkers (`TYPE_CHECKING` is False at runtime):
    # the star imports make every public name of the submodules visible to them.
    from .configuration_dust3r import *
    from .feature_extraction_dust3r import *
    from .image_processing_dust3r import *
    from .image_processing_dust3r_fast import *
    from .modeling_flax_dust3r import *
    from .modeling_tf_dust3r import *
    from .modeling_dust3r import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the
    # import structure derived from this file's path (presumably so that submodules are only
    # imported on first attribute access — see `_LazyModule` for the exact semantics).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,151 @@
# coding=utf-8
# Copyright 2025 Google AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dust3R model configuration"""
from collections import OrderedDict
from collections.abc import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class Dust3RConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Dust3RModel`]. It is used to instantiate a Dust3R
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Dust3R
    [google/dust3r-base-patch16-224](https://huggingface.co/google/dust3r-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        encoder_stride (`int`, *optional*, defaults to 16):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
        pooler_output_size (`int`, *optional*):
            Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
        pooler_act (`str`, *optional*, defaults to `"tanh"`):
            The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
            Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
            supported for Tensorflow.

    Example:

    ```python
    >>> from transformers import Dust3RConfig, Dust3RModel

    >>> # Initializing a Dust3R dust3r-base-patch16-224 style configuration
    >>> configuration = Dust3RConfig()

    >>> # Initializing a model (with random weights) from the dust3r-base-patch16-224 style configuration
    >>> model = Dust3RModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # NOTE(review): the `google/dust3r-base-patch16-224` checkpoint named in the docstring appears to be
    # carried over from the template this file was generated from — confirm the real checkpoint name.

    # Identifier under which this configuration is registered (the auto-mapping tables use "dust3r").
    model_type = "dust3r"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        encoder_stride=16,
        pooler_output_size=None,
        pooler_act="tanh",
        **kwargs,
    ):
        # Any remaining keyword arguments are handled by the `PretrainedConfig` base class.
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.encoder_stride = encoder_stride
        # Truthiness check: both `None` and `0` fall back to `hidden_size`.
        self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
        self.pooler_act = pooler_act
class Dust3ROnnxConfig(OnnxConfig):
    """ONNX export description for Dust3R: one `pixel_values` input with four named dynamic axes."""

    # Minimum torch version declared for ONNX export of this model.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the ordered mapping from input name to its `{axis index: axis name}` description."""
        pixel_value_axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict(pixel_values=pixel_value_axes)

    @property
    def atol_for_validation(self) -> float:
        """Return the absolute tolerance to use when validating the exported model."""
        return 1e-4


__all__ = ["Dust3RConfig", "Dust3ROnnxConfig"]

View File

@ -0,0 +1,288 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Dust3R."""
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
from ...utils.import_utils import requires
logger = logging.get_logger(__name__)
@requires(backends=("vision",))
class Dust3RImageProcessor(BaseImageProcessor):
    r"""
    Image processor for Dust3R: optional RGB conversion, resizing, rescaling and normalization.

    Every argument below is stored on the instance as a default and can be overridden for a single call through
    the parameter of the same name in `preprocess`.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize images to `(size["height"], size["width"])`.
        size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Target size of the image after resizing.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter used when resizing.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to multiply pixel values by `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Factor applied to pixel values when rescaling.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize images with `image_mean` and `image_std`.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Per-channel mean (or a single float) used for normalization.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Per-channel standard deviation (or a single float) used for normalization.
        do_convert_rgb (`bool`, *optional*):
            Whether to convert images to RGB before any other transformation.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Resolve the size first so that an invalid `size` fails before any attribute is set.
        if size is None:
            size = {"height": 224, "width": 224}
        resolved_size = get_size_dict(size)

        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.size = resolved_size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.image_mean = IMAGENET_STANDARD_MEAN if image_mean is None else image_mean
        self.image_std = IMAGENET_STANDARD_STD if image_std is None else image_std
        self.do_convert_rgb = do_convert_rgb

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize `image` to exactly `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                The image to resize.
            size (`Dict[str, int]`):
                Output size given as `{"height": int, "width": int}`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Resampling filter passed on to the resize transform.
            data_format (`ChannelDimension` or `str`, *optional*):
                Channel dimension format of the returned image; forwarded unchanged to the resize transform.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Channel dimension format of `image`; forwarded unchanged to the resize transform.

        Returns:
            `np.ndarray`: The resized image.

        Raises:
            ValueError: If the normalized `size` dictionary lacks the `height` or `width` key.
        """
        target = get_size_dict(size)
        if not ("height" in target and "width" in target):
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {target.keys()}")

        # Inside a method body, `resize` resolves to the module-level transform, not to this method.
        return resize(
            image,
            size=(target["height"], target["width"]),
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        do_convert_rgb: Optional[bool] = None,
    ):
        """
        Preprocess a single image or a batch of images into model inputs.

        Any argument left as `None` falls back to the value stored on the processor instance.

        Args:
            images (`ImageInput`):
                Image or batch of images. Pixel values are expected in the 0-255 range; for images already scaled
                to 0-1, pass `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the images.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Output size as `{"height": h, "width": w}` when resizing.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter; only used when `do_resize` is `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to multiply pixel values by `rescale_factor`.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Factor used when `do_rescale` is `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the images.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean used when `do_normalize` is `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation used when `do_normalize` is `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                Tensor type of the returned batch, passed to `BatchFeature` as `tensor_type`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                Channel dimension format of the output images.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Channel dimension format of the input images. When unset it is inferred from the first image and
                assumed to be shared by the whole batch.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the images to RGB first.

        Returns:
            `BatchFeature`: A batch with a single `pixel_values` entry holding the processed images.

        Raises:
            ValueError: If `images` fails the `valid_images` check.
        """
        # Fill every unset option from the instance defaults.
        if do_resize is None:
            do_resize = self.do_resize
        if do_rescale is None:
            do_rescale = self.do_rescale
        if do_normalize is None:
            do_normalize = self.do_normalize
        if resample is None:
            resample = self.resample
        if rescale_factor is None:
            rescale_factor = self.rescale_factor
        if image_mean is None:
            image_mean = self.image_mean
        if image_std is None:
            image_std = self.image_std
        if do_convert_rgb is None:
            do_convert_rgb = self.do_convert_rgb
        if size is None:
            size = self.size
        size_dict = get_size_dict(size)

        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if do_convert_rgb:
            images = list(map(convert_to_rgb, images))

        # Every transformation below works on numpy arrays.
        images = list(map(to_numpy_array, images))

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # The first image's layout is taken as the layout of the whole batch.
            input_data_format = infer_channel_dimension_format(images[0])

        # Collect the enabled stages, then run them one stage at a time over the whole batch.
        stages = []
        if do_resize:
            stages.append(
                lambda img: self.resize(
                    image=img, size=size_dict, resample=resample, input_data_format=input_data_format
                )
            )
        if do_rescale:
            stages.append(
                lambda img: self.rescale(image=img, scale=rescale_factor, input_data_format=input_data_format)
            )
        if do_normalize:
            stages.append(
                lambda img: self.normalize(
                    image=img, mean=image_mean, std=image_std, input_data_format=input_data_format
                )
            )
        stages.append(
            lambda img: to_channel_dimension_format(img, data_format, input_channel_dim=input_data_format)
        )

        for stage in stages:
            images = [stage(img) for img in images]

        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)


__all__ = ["Dust3RImageProcessor"]

View File

@ -0,0 +1,827 @@
# coding=utf-8
# Copyright 2025 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Dust3R model."""
import collections.abc
import math
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging, torch_int
from .configuration_dust3r import Dust3RConfig
logger = logging.get_logger(__name__)
# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings with ViT->Dust3R
class Dust3REmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: Dust3RConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        # Learnable [CLS] token of shape (1, 1, hidden_size); expanded over the batch in `forward`.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # The mask token only exists when requested; otherwise the attribute is `None`.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = Dust3RPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # One position per patch plus one extra position for the [CLS] token prepended in `forward`.
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211

        Args:
            embeddings (`torch.Tensor`):
                Token embeddings including the [CLS] token; only `shape[1]` and `shape[-1]` are read.
            height (`int`):
                Height of the input image in pixels.
            width (`int`):
                Width of the input image in pixels.

        Returns:
            `torch.Tensor`: The stored position embeddings when no interpolation is needed, otherwise the class
            position embedding concatenated with the resized patch position embeddings.
        """
        # The first token is [CLS], so it is excluded from both patch counts.
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        # Target grid size measured in patches.
        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # The stored patch positions are laid out as a square grid, so `num_positions` must be a perfect
        # square for the reshape below to succeed.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        # Move the embedding dimension to axis 1 before interpolating over the two grid axes.
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        # Back to (1, new_height * new_width, dim).
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """
        Embed `pixel_values` into a token sequence: patch embeddings, optional mask-token substitution, a
        prepended [CLS] token, added position embeddings and dropout.

        Args:
            pixel_values (`torch.Tensor`):
                Input images; unpacked as `(batch_size, num_channels, height, width)`.
            bool_masked_pos (`torch.BoolTensor`, *optional*):
                Mask selecting the patch positions to replace by the mask token. Using it requires the module
                to have been built with `use_mask_token=True`; otherwise `self.mask_token` is `None` and the
                `.expand` call below fails.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to resize the position embeddings to the input resolution instead of adding the stored
                ones directly.

        Returns:
            `torch.Tensor`: The embedded sequence after dropout.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings
# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->Dust3R
class Dust3RPatchEmbeddings(nn.Module):
    """
    Patch-embedding layer: maps `pixel_values` of shape `(batch_size, num_channels, height, width)` to a sequence
    of patch embeddings of shape `(batch_size, seq_length, hidden_size)` that a Transformer can consume.

    The projection is a single `nn.Conv2d` whose kernel size and stride both equal the patch size, so every
    non-overlapping patch becomes one token.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # A scalar size describes a square; iterables are used unchanged.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Project an image batch to patch embeddings.

        Args:
            pixel_values: Tensor of shape `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding: When false, the spatial size must equal the configured `image_size`.

        Returns:
            Tensor of shape `(batch_size, seq_length, hidden_size)`.

        Raises:
            ValueError: If the channel count differs from the configured one, or if the image size differs from
                the configured one while `interpolate_pos_encoding` is false.
        """
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )

        size_mismatch = height != self.image_size[0] or width != self.image_size[1]
        if not interpolate_pos_encoding and size_mismatch:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )

        # (B, hidden, H', W') -> (B, hidden, H'*W') -> (B, H'*W', hidden)
        return self.projection(pixel_values).flatten(2).transpose(1, 2)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (non-fused) scaled dot-product attention.

    Args:
        module: Calling module; only its `training` flag is read, to decide whether dropout is active.
        query: Query tensor; its last two dimensions are treated as `(seq_len, head_dim)`.
        key: Key tensor, laid out like `query`.
        value: Value tensor, laid out like `query`.
        attention_mask: Optional tensor multiplied into the attention probabilities *after* softmax and
            dropout (it is used as a head mask here, not as an additive mask).
        scaling: Factor applied to the raw `query @ key^T` scores.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted for interface compatibility with other attention backends; ignored.

    Returns:
        A tuple `(attn_output, attn_weights)`. `attn_output` has dimensions 1 and 2 swapped relative to
        `attn_weights @ value` and is made contiguous.
    """
    # Raw attention scores; the softmax runs in float32 and is cast back to the query dtype.
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    # This drops entire tokens to attend to, as in the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Mask heads if requested.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dust3R
class Dust3RSelfAttention(nn.Module):
    """
    Multi-head self-attention: projects the hidden states to queries, keys and values, then delegates the
    attention computation to the backend selected by `config._attn_implementation`.
    """

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        divisible = config.hidden_size % config.num_attention_heads == 0
        if not divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        # Registration order (query, key, value) is kept so parameter order and init RNG use are unchanged.
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into `(num_heads, head_size)` and move the head axis to position 1."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input sequence; projected by `self.query`, `self.key` and `self.value`.
            head_mask: Optional mask forwarded to the attention backend in the attention-mask slot.
            output_attentions: Whether to also return the attention probabilities.

        Returns:
            `(context_layer,)`, or `(context_layer, attention_probs)` when `output_attentions` is true.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        implementation = self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward
        if implementation == "sdpa" and output_attentions:
            # Eager stays selected for this combination; only a warning is emitted.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        elif implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=self.dropout_prob if self.training else 0.0,
        )

        # Merge the trailing (heads, head_size) dimensions back into a single hidden dimension.
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(merged_shape)

        if output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dust3R
class Dust3RSelfOutput(nn.Module):
    """
    Output projection of the attention block: a linear layer followed by dropout.

    `input_tensor` is accepted but not used here; the residual connection is added in `Dust3RLayer`, because the
    layernorm is applied before each block.
    """

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `dropout(dense(hidden_states))`; `input_tensor` is ignored."""
        return self.dropout(self.dense(hidden_states))
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dust3R
class Dust3RAttention(nn.Module):
    """Self-attention followed by its output projection, with support for pruning attention heads."""

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        self.attention = Dust3RSelfAttention(config)
        self.output = Dust3RSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """
        Remove the given attention heads from the query/key/value projections and from the output projection.

        Args:
            heads: Indices of the heads to prune. Nothing happens when it is empty.
        """
        if not heads:
            return

        attention = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune the three input projections, then the input side (dim=1) of the output projection.
        for projection_name in ("query", "key", "value"):
            pruned = prune_linear_layer(getattr(attention, projection_name), index)
            setattr(attention, projection_name, pruned)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update the hyper-parameters and remember which heads are gone.
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Apply self-attention and the output projection.

        Returns:
            `(attention_output,)`, followed by whatever extra entries the self-attention module returned
            (the attention probabilities when `output_attentions` is true).
        """
        attention_results = self.attention(hidden_states, head_mask, output_attentions)
        projected = self.output(attention_results[0], hidden_states)
        return (projected,) + attention_results[1:]
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Dust3R
class Dust3RIntermediate(nn.Module):
    """First half of the feed-forward block: expansion to `intermediate_size` followed by the activation."""

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN or already something callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Dust3R
class Dust3ROutput(nn.Module):
    """Second half of the feed-forward block: projection back to `hidden_size`, dropout, then the residual add."""

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `dropout(dense(hidden_states)) + input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Dust3R
class Dust3RLayer(nn.Module):
    """
    One pre-norm Transformer block (this corresponds to the Block class in the timm implementation):
    layernorm -> attention -> residual, then layernorm -> feed-forward -> residual.
    """

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Dust3RAttention(config)
        self.intermediate = Dust3RIntermediate(config)
        self.output = Dust3ROutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """
        Run the block on `hidden_states`.

        Returns:
            `(layer_output,)`, followed by the attention probabilities when the attention module returned them.
        """
        # Layernorm is applied before self-attention.
        normed_input = self.layernorm_before(hidden_states)
        attention_results = self.attention(normed_input, head_mask, output_attentions=output_attentions)
        extras = attention_results[1:]  # self-attention weights, if they were requested

        # First residual connection.
        hidden_states = attention_results[0] + hidden_states

        # Layernorm is also applied after self-attention; the second residual is added inside `self.output`.
        feed_forward_hidden = self.intermediate(self.layernorm_after(hidden_states))
        layer_output = self.output(feed_forward_hidden, hidden_states)

        return (layer_output,) + extras
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dust3R
class Dust3REncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Dust3RLayer` blocks applied one after another."""

    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Dust3RLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run every layer in order.

        Args:
            hidden_states: Input to the first layer.
            head_mask: Optional per-layer head masks; entry `i` is handed to layer `i`.
            output_attentions: Whether to collect each layer's attention probabilities.
            output_hidden_states: Whether to collect the input of every layer plus the final output.
            return_dict: Return a `BaseModelOutput` when true, otherwise a tuple of the non-`None` results.
        """
        collected_states = [] if output_hidden_states else None
        collected_attentions = [] if output_attentions else None

        for index, block in enumerate(self.layer):
            if collected_states is not None:
                collected_states.append(hidden_states)

            layer_head_mask = None if head_mask is None else head_mask[index]

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = block(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            if collected_attentions is not None:
                collected_attentions.append(layer_outputs[1])

        # The output of the last layer closes the hidden-state trace.
        if collected_states is not None:
            collected_states.append(hidden_states)

        all_hidden_states = None if collected_states is None else tuple(collected_states)
        all_self_attentions = None if collected_attentions is None else tuple(collected_attentions)

        if not return_dict:
            candidates = (hidden_states, all_hidden_states, all_self_attentions)
            return tuple(item for item in candidates if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
@auto_docstring
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->Dust3R,vit->dust3r
class Dust3RPreTrainedModel(PreTrainedModel):
    # NOTE(review): no hand-written class docstring here; the class is wrapped by `@auto_docstring`, which
    # presumably generates it. Confirm before adding one.
    config_class = Dust3RConfig
    base_model_prefix = "dust3r"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Dust3REmbeddings", "Dust3RLayer"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        std = self.config.initializer_range

        def _truncated_normal(parameter):
            # Sample in fp32 and cast back to the parameter dtype, to avoid
            # `trunc_normal_cpu` not implemented in `half` issues.
            sampled = nn.init.trunc_normal_(parameter.data.to(torch.float32), mean=0.0, std=std)
            return sampled.to(parameter.dtype)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data = _truncated_normal(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dust3REmbeddings):
            # Sampling order (position embeddings, then [CLS]) is kept so RNG consumption is unchanged.
            module.position_embeddings.data = _truncated_normal(module.position_embeddings)
            module.cls_token.data = _truncated_normal(module.cls_token)
            if module.mask_token is not None:
                module.mask_token.data.zero_()
@auto_docstring
# Copied from transformers.models.vit.modeling_vit.ViTModel with ViT->Dust3R
class Dust3RModel(Dust3RPreTrainedModel):
    # Pipeline: embeddings -> encoder -> final layernorm -> optional pooler on the first token.
    def __init__(self, config: Dust3RConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        """
        super().__init__(config)
        self.config = config

        self.embeddings = Dust3REmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = Dust3REncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = Dust3RPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dust3RPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_index, heads in heads_to_prune.items():
            self.encoder.layer[layer_index].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.layernorm(encoder_outputs[0])
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if not return_dict:
            # The pooled output is only part of the tuple when a pooler exists.
            head_outputs = (sequence_output,) if pooled_output is None else (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->Dust3R
class Dust3RPooler(nn.Module):
    """Pools a sequence by projecting and activating the hidden state of its first token."""

    def __init__(self, config: Dust3RConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
        self.activation = ACT2FN[config.pooler_act]

    def forward(self, hidden_states):
        """Return `activation(dense(hidden_states[:, 0]))`."""
        # "Pooling" here simply means taking the hidden state that corresponds to the first token.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
@auto_docstring(
    custom_intro="""
    Dust3R Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """
)
# Copied from transformers.models.vit.modeling_vit.ViTForMaskedImageModeling with ViT->Dust3R,vit->dust3r,google/vit-base-patch16-224-in21k->google/vit-base-patch16-224
class Dust3RForMaskedImageModeling(Dust3RPreTrainedModel):
    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__(config)

        self.dust3r = Dust3RModel(config, add_pooling_layer=False, use_mask_token=True)

        # A 1x1 convolution produces encoder_stride**2 * num_channels feature maps, which PixelShuffle then
        # rearranges into num_channels maps upscaled by encoder_stride.
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=config.hidden_size,
                out_channels=config.encoder_stride**2 * config.num_channels,
                kernel_size=1,
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, Dust3RForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
        >>> model = Dust3RForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                "the reconstructed image has the same dimensions as the input. "
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
            )

        outputs = self.dust3r(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Drop the first ([CLS]) token, then reshape to (batch_size, num_channels, height, width).
        # The patch grid is taken to be square: its side is floor(sqrt(sequence_length)).
        patch_tokens = outputs[0][:, 1:]
        batch_size, sequence_length, num_channels = patch_tokens.shape
        height = width = math.floor(sequence_length**0.5)
        feature_map = patch_tokens.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(feature_map)

        masked_im_loss = None
        if bool_masked_pos is not None:
            # Expand the patch-level mask to pixel resolution and add a channel axis for broadcasting.
            grid_size = self.config.image_size // self.config.patch_size
            patch_mask = bool_masked_pos.reshape(-1, grid_size, grid_size)
            pixel_mask = patch_mask.repeat_interleave(self.config.patch_size, 1)
            pixel_mask = pixel_mask.repeat_interleave(self.config.patch_size, 2)
            mask = pixel_mask.unsqueeze(1).contiguous()

            # L1 loss averaged over the masked pixels only; 1e-5 guards against an all-false mask.
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            if masked_im_loss is None:
                return output
            return (masked_im_loss,) + output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    Dust3R Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune Dust3R on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """
)
# Copied from transformers.models.vit.modeling_vit.ViTForImageClassification with ViT->Dust3R,vit->dust3r
class Dust3RForImageClassification(Dust3RPreTrainedModel):
    def __init__(self, config: Dust3RConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dust3r = Dust3RModel(config, add_pooling_layer=False)

        # Classifier head; with no labels configured the first-token features pass through unchanged.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.dust3r(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        # Classify from the first token of the final hidden state.
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)

            # Infer the problem type once and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            if loss is None:
                return output
            return (loss,) + output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public names exported by this module.
__all__ = [
    "Dust3RForImageClassification",
    "Dust3RForMaskedImageModeling",
    "Dust3RModel",
    "Dust3RPreTrainedModel",
]

View File

@ -0,0 +1,266 @@
# TODO:
"""
Use modular on ViT or other transformers vision models you think can fit Dust3R
We should replicate the following: https://github.com/ibaiGorordo/dust3r-pytorch-inference-minimal/blob/main/dust3r/dust3r.py
We are using multiple "blocks" such as Dust3rEncoder, Dust3rDecoder, etc. I would suggest using inheritance, i.e. inheriting
from, say, ViT.
Say I want to replicate the Dust3rEncoder; I would do something like this:
class Dust3rEncoder(ViTEncoder):
# my custom implementation of dust3r
class Dust3RPreTrainedModel(ViTPreTrainedModel):
pass
class Dust3RModel(Dust3RPreTrainedModel):
def __init__(....):
self.encoder = Dust3rEncoder(...)
self.decoder = Dust3rDecoder(...)
self.head = Dust3rHead(...)
def forward():
# add forward logic similar to https://github.com/ibaiGorordo/dust3r-pytorch-inference-minimal/blob/main/dust3r/dust3r.py#L50
# test your created model first
# random weights is OK for now , let's first make sure a first pass works
"""
import torch
import torch.nn as nn
try:
from ..vit.modeling_vit import (
ViTEncoder,
ViTEmbeddings,
ViTPreTrainedModel,
)
from .configuration_dust3r import Dust3RConfig
except ImportError:
# Fallback for direct execution
import sys
import os
# Add the transformers src directory to path
transformers_path = os.path.join(os.path.dirname(__file__), '..', '..', '..')
sys.path.insert(0, transformers_path)
from transformers.models.vit.modeling_vit import (
ViTEncoder,
ViTEmbeddings,
ViTPreTrainedModel,
)
from transformers.models.dust3r.configuration_dust3r import Dust3RConfig
try:
    from .third_party import RoPE2D  # type: ignore
except ImportError:  # ModuleNotFoundError is a subclass of ImportError, so it is covered too

    class RoPE2D:  # pylint: disable=too-few-public-methods
        """Do-nothing stand-in used when `.third_party` cannot be imported; accepts and ignores all arguments."""

        def __init__(self, *_args, **_kwargs):
            pass
# -----------------------------------------------------------------------------
# Simple Encoder (inherits from ViT)
# -----------------------------------------------------------------------------
class Dust3rEncoder(ViTEncoder):
    """
    Simple encoder that inherits from ViTEncoder.

    For now the ViT encoder is used as-is (the constructor is inherited unchanged). Add Dust3R-specific
    modifications here by overriding `__init__`/`forward` when they are needed.
    """
# -----------------------------------------------------------------------------
# Simple Decoder
# -----------------------------------------------------------------------------
class Dust3rDecoder(nn.Module):
    """
    Simple two-view decoder following the reference implementation structure.

    Both views are projected by the same linear layer and then refined by a stack of
    `nn.TransformerDecoderLayer`s, in which each view attends to itself and cross-attends to the other view.
    The same layer weights are used for both directions.
    """

    def __init__(self, config):
        super().__init__()

        # Simple linear projection for decoder
        self.decoder_embed = nn.Linear(config.hidden_size, config.hidden_size)

        # Simple decoder layers (placeholder for now)
        self.decoder_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                batch_first=True
            ) for _ in range(6)  # 6 decoder layers like in reference
        ])

        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, encoder_features1, encoder_features2):
        """
        Decode the features of two images with cross-attention between them.

        Args:
            encoder_features1: Features from first image
            encoder_features2: Features from second image

        Returns:
            A tuple `(dec_feat1, dec_feat2)` with the decoded, layer-normalized features of each image.
            Swapping the two inputs swaps the two outputs.
        """
        # Project to decoder dimension
        dec_feat1 = self.decoder_embed(encoder_features1)
        dec_feat2 = self.decoder_embed(encoder_features2)

        for layer in self.decoder_layers:
            # Both right-hand sides are evaluated before either name is rebound, so each view attends to the
            # other view's features from the *previous* layer. Updating them one after the other would let the
            # second view see the already-updated first view and make the decoder order-dependent.
            dec_feat1, dec_feat2 = layer(dec_feat1, dec_feat2), layer(dec_feat2, dec_feat1)

        # Apply final norm
        return self.norm(dec_feat1), self.norm(dec_feat2)
# -----------------------------------------------------------------------------
# Simple Head
# -----------------------------------------------------------------------------
class Dust3rHead(nn.Module):
    """Simple prediction head: two independent linear layers producing a depth and a confidence value per token."""

    def __init__(self, config):
        super().__init__()
        # One output channel each, like in the reference implementation.
        self.depth_head = nn.Linear(config.hidden_size, 1)
        self.conf_head = nn.Linear(config.hidden_size, 1)

    def forward(self, decoder_features):
        """
        Args:
            decoder_features: Output from decoder

        Returns:
            A tuple `(depth, confidence)`; both keep the leading dimensions of `decoder_features` and have a
            last dimension of size 1.
        """
        return self.depth_head(decoder_features), self.conf_head(decoder_features)
# -----------------------------------------------------------------------------
# Main Model (inherits from ViT)
# -----------------------------------------------------------------------------
class Dust3RPreTrainedModel(ViTPreTrainedModel):
    """Base class for Dust3R models; inherits everything from ViTPreTrainedModel except the model prefix."""

    # Overrides the prefix inherited from the ViT base class.
    base_model_prefix = "dust3r"
class Dust3RModel(Dust3RPreTrainedModel):
    """
    Main Dust3R model, built from four parts:
    - embeddings = ViTEmbeddings (reused from ViT)
    - encoder = Dust3rEncoder (inherits from ViT)
    - decoder = Dust3rDecoder
    - head = Dust3rHead

    Both input images go through the same embeddings and encoder; the decoder then exchanges information
    between the two views and the head predicts depth and confidence for each of them.
    """

    def __init__(self, config):
        super().__init__(config)

        self.embeddings = ViTEmbeddings(config)  # reused from ViT
        self.encoder = Dust3rEncoder(config)  # inherits from ViTEncoder
        self.decoder = Dust3rDecoder(config)  # simple custom implementation
        self.head = Dust3rHead(config)  # simple custom implementation

        self.post_init()

    def forward(self, pixel_values1: torch.Tensor, pixel_values2: torch.Tensor):
        """
        Forward logic similar to the reference implementation.

        Args:
            pixel_values1: First image tensor (B, C, H, W)
            pixel_values2: Second image tensor (B, C, H, W)

        Returns:
            A dict with the keys `depth1`, `confidence1`, `depth2` and `confidence2`, holding the head outputs
            for the first and second image respectively.
        """
        # Embed both images first, then encode both; this call order is kept on purpose so that random ops
        # such as dropout are consumed in the same sequence.
        embedded = [self.embeddings(pixels) for pixels in (pixel_values1, pixel_values2)]
        features = [self.encoder(tokens).last_hidden_state for tokens in embedded]

        # Decode with cross-attention between images
        decoded1, decoded2 = self.decoder(features[0], features[1])

        # Apply heads to get final outputs
        depth1, conf1 = self.head(decoded1)
        depth2, conf2 = self.head(decoded2)

        return {
            'depth1': depth1,
            'confidence1': conf1,
            'depth2': depth2,
            'confidence2': conf2,
        }
# -----------------------------------------------------------------------------
# Simple test function
# -----------------------------------------------------------------------------
def test_dust3r_model():
    """Smoke-test the basic Dust3R model: build it with random weights, run one forward pass, print the shapes."""
    settings = {
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-6,
        "image_size": 224,
        "patch_size": 16,
        "num_channels": 3,
        "qkv_bias": True,
    }
    config = Dust3RConfig(**settings)

    banner = (
        "Testing basic Dust3R model...",
        "Architecture:",
        " ✓ Simple Dust3rDecoder with cross-attention",
        " ✓ Simple Dust3rHead for depth/confidence",
    )
    for line in banner:
        print(line)
    print()

    # Create the model before drawing the inputs so the random stream is consumed in the same order.
    model = Dust3RModel(config)
    model.eval()

    # Two random image batches stand in for the two views.
    batch_size = 2
    image_shape = (batch_size, 3, 224, 224)
    pixel_values1 = torch.randn(*image_shape)
    pixel_values2 = torch.randn(*image_shape)

    with torch.no_grad():
        outputs = model(pixel_values1, pixel_values2)

    print("Output shapes:")
    for key, value in outputs.items():
        print(f" {key}: {value.shape}")
    print("Basic test passed! Simple Dust3R model works with random weights.")


if __name__ == "__main__":
    test_dust3r_model()

View File

View File

@ -0,0 +1,112 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_vision_available():
from transformers import Dust3RImageProcessor
if is_torchvision_available():
from transformers import Dust3RImageProcessorFast
class Dust3RImageProcessingTester:
    """Container for the settings used to build Dust3R image processors in tests.

    It stores the processor keyword arguments and the parameters used to
    generate test images, and exposes helpers that turn them into a processor
    kwargs dict, an expected output shape and a batch of input images.
    """

    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_normalize=True,
        image_mean=None,
        image_std=None,
    ):
        """Store the test settings.

        Args:
            parent: The test case that owns this tester.
            batch_size: Number of images generated by `prepare_image_inputs`.
            num_channels: Number of channels of the generated images.
            image_size: Stored as-is; not otherwise used by this class.
            min_resolution: Lower resolution bound forwarded to the image generator.
            max_resolution: Upper resolution bound forwarded to the image generator.
            do_resize: Value of the processor's `do_resize` argument.
            size: Target size dict; defaults to `{"height": 18, "width": 18}`.
            do_normalize: Value of the processor's `do_normalize` argument.
            image_mean: Per-channel mean; defaults to `[0.5, 0.5, 0.5]`.
            image_std: Per-channel std; defaults to `[0.5, 0.5, 0.5]`.
        """
        # Build the defaults per call: list/dict literals used directly as
        # default arguments would be shared (and mutable) across all instances.
        size = size if size is not None else {"height": 18, "width": 18}
        image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]

        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std

    def prepare_image_processor_dict(self):
        """Return the keyword arguments used to instantiate an image processor."""
        return {
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_normalize": self.do_normalize,
            "do_resize": self.do_resize,
            "size": self.size,
        }

    def expected_output_image_shape(self, images):
        """Return `(num_channels, height, width)` from the configured size.

        `images` is accepted for signature compatibility but is not inspected.
        """
        return self.num_channels, self.size["height"], self.size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        """Generate a batch of test images via the shared `prepare_image_inputs` helper.

        The batch size, channel count and resolution bounds come from this
        tester; the three flags are forwarded unchanged.
        """
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )
@require_torch
@require_vision
class Dust3RImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    """Image-processor tests for Dust3R, layered on the shared image-processing mixin.

    Both the regular and the fast processor classes are registered when their
    optional dependencies are available.
    """

    image_processing_class = Dust3RImageProcessor if is_vision_available() else None
    fast_image_processing_class = Dust3RImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        """Run the mixin set-up, then attach the Dust3R-specific tester."""
        super().setUp()
        self.image_processor_tester = Dust3RImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        """Keyword arguments used to build the processors under test."""
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        """Every processor instance exposes the configured attributes."""
        expected_attributes = ("image_mean", "image_std", "do_normalize", "do_resize", "size")
        # `image_processor_list` is not defined in this class; presumably provided by the mixin.
        for processor_cls in self.image_processor_list:
            processor = processor_cls(**self.image_processor_dict)
            for attribute in expected_attributes:
                self.assertTrue(hasattr(processor, attribute))

    def test_image_processor_from_dict_with_kwargs(self):
        """`from_dict` keeps the dict's size and lets a `size` kwarg override it."""
        for processor_cls in self.image_processor_list:
            from_defaults = processor_cls.from_dict(self.image_processor_dict)
            self.assertEqual(from_defaults.size, {"height": 18, "width": 18})

            # An integer `size` override is expected to come back as a square size dict.
            overridden = processor_cls.from_dict(self.image_processor_dict, size=42)
            self.assertEqual(overridden.size, {"height": 42, "width": 42})

View File

@ -0,0 +1,329 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Dust3R model."""
import unittest
from transformers import Dust3RConfig
from transformers.testing_utils import (
require_accelerate,
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import Dust3RForImageClassification, Dust3RForMaskedImageModeling, Dust3RModel
if is_vision_available():
from PIL import Image
from transformers import Dust3RImageProcessor
class Dust3RModelTester:
    """Builds small Dust3R configs and random inputs, and checks model output shapes."""

    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        scope=None,
        encoder_stride=2,
        mask_ratio=0.5,
        attn_implementation="eager",
    ):
        """Record the test hyper-parameters and derive the sequence/mask sizes."""
        # Test harness wiring and input geometry.
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels

        # Model hyper-parameters forwarded to the config.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope
        self.encoder_stride = encoder_stride
        self.attn_implementation = attn_implementation

        # Sequence length is the patch count plus one extra token
        # (presumably a [CLS] token, as the original comment states; confirm
        # against the model implementation).
        patches_per_side = image_size // patch_size
        patch_count = patches_per_side**2
        self.seq_length = patch_count + 1

        self.mask_ratio = mask_ratio
        self.num_masks = int(mask_ratio * self.seq_length)
        self.mask_length = patch_count

    @staticmethod
    def _ready(model):
        """Move `model` to the test device, put it in eval mode and return it."""
        model.to(torch_device)
        model.eval()
        return model

    def prepare_config_and_inputs(self):
        """Return `(config, pixel_values, labels)`; `labels` is None when `use_labels` is False."""
        image_dims = [self.batch_size, self.num_channels, self.image_size, self.image_size]
        pixel_values = floats_tensor(image_dims)
        labels = ids_tensor([self.batch_size], self.type_sequence_label_size) if self.use_labels else None
        return self.get_config(), pixel_values, labels

    def get_config(self):
        """Build a `Dust3RConfig` from the stored hyper-parameters."""
        config_kwargs = {
            "image_size": self.image_size,
            "patch_size": self.patch_size,
            "num_channels": self.num_channels,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "intermediate_size": self.intermediate_size,
            "hidden_act": self.hidden_act,
            "hidden_dropout_prob": self.hidden_dropout_prob,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
            "is_decoder": False,
            "initializer_range": self.initializer_range,
            "encoder_stride": self.encoder_stride,
            "attn_implementation": self.attn_implementation,
        }
        return Dust3RConfig(**config_kwargs)

    def create_and_check_model(self, config, pixel_values, labels):
        """Check that the base model's last hidden state is (batch, seq_length, hidden_size)."""
        model = self._ready(Dust3RModel(config=config))
        output = model(pixel_values)
        self.parent.assertEqual(
            output.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
        )

    def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
        """Check that the reconstruction has the input image shape, for the configured
        channel count and again for single-channel (greyscale) input."""
        side = self.image_size

        model = self._ready(Dust3RForMaskedImageModeling(config=config))
        output = model(pixel_values)
        self.parent.assertEqual(output.reconstruction.shape, (self.batch_size, self.num_channels, side, side))

        # Greyscale variant: the config is mutated in place before rebuilding the model.
        config.num_channels = 1
        model = self._ready(Dust3RForMaskedImageModeling(config))
        grey_pixels = floats_tensor([self.batch_size, 1, side, side])
        output = model(grey_pixels)
        self.parent.assertEqual(output.reconstruction.shape, (self.batch_size, 1, side, side))

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        """Check that the classification logits are (batch, num_labels), with labels on the
        configured input and without labels on single-channel (greyscale) input."""
        logits_shape = (self.batch_size, self.type_sequence_label_size)

        config.num_labels = self.type_sequence_label_size
        model = self._ready(Dust3RForImageClassification(config))
        output = model(pixel_values, labels=labels)
        self.parent.assertEqual(output.logits.shape, logits_shape)

        # Greyscale variant: the config is mutated in place before rebuilding the model.
        config.num_channels = 1
        model = self._ready(Dust3RForImageClassification(config))
        grey_pixels = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
        output = model(grey_pixels)
        self.parent.assertEqual(output.logits.shape, logits_shape)

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where the dict only carries `pixel_values`."""
        config, pixel_values, _ = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}
@require_torch
class Dust3RModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Common model tests for the Dust3R classes. Some tests inherited from test_modeling_common.py are
    overridden or skipped here because Dust3R does not use input_ids, inputs_embeds, attention_mask
    and seq_length.
    """

    all_model_classes = (
        (Dust3RModel, Dust3RForImageClassification, Dust3RForMaskedImageModeling) if is_torch_available() else ()
    )

    # Feature switches read by the shared test mixins.
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = Dust3RModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Dust3RConfig, has_text_modality=False, hidden_size=37)

    @unittest.skip(
        "Since `torch==2.3+cu121`, although this test passes, many subsequent tests have `CUDA error: misaligned address`."
        "If `nvidia-xxx-cu118` are also installed, no failure (even with `torch==2.3+cu121`)."
    )
    def test_multi_gpu_data_parallel_forward(self):
        super().test_multi_gpu_data_parallel_forward()

    def test_config(self):
        """Run the shared configuration checks."""
        self.config_tester.run_common_tests()

    @unittest.skip(reason="Dust3R does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        """Input embeddings are a module; output embeddings are absent or a linear layer."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            output_embeddings = model.get_output_embeddings()
            self.assertTrue(output_embeddings is None or isinstance(output_embeddings, nn.Linear))

    def test_model(self):
        """Shape check for the base model."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*prepared)

    def test_for_masked_image_modeling(self):
        """Shape check for the masked-image-modeling head."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_image_modeling(*prepared)

    def test_for_image_classification(self):
        """Shape check for the image-classification head."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*prepared)

    @slow
    def test_model_from_pretrained(self):
        """Loading the pretrained checkpoint yields a model object."""
        # NOTE(review): this checkpoint id looks like a mechanical rename of a ViT
        # checkpoint name; confirm it actually exists on the Hub.
        checkpoint = "google/dust3r-base-patch16-224"
        self.assertIsNotNone(Dust3RModel.from_pretrained(checkpoint))
# We will verify our results on an image of cute cats
def prepare_img():
    """Open and return the COCO sample image used by the integration tests."""
    return Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
@require_torch
@require_vision
class Dust3RModelIntegrationTest(unittest.TestCase):
    """Slow end-to-end checks that run pretrained checkpoints on a real image."""

    # NOTE(review): the checkpoint ids and expected values used in this class look
    # like a mechanical rename of ViT/DINO test fixtures; confirm that the
    # checkpoints exist and that the numbers match Dust3R outputs.

    @cached_property
    def default_image_processor(self):
        """Image processor loaded from the base checkpoint, or None without vision support."""
        if not is_vision_available():
            return None
        return Dust3RImageProcessor.from_pretrained("google/dust3r-base-patch16-224")

    @slow
    def test_inference_image_classification_head(self):
        """The classification head yields (1, 1000) logits whose first values match the reference."""
        model = Dust3RForImageClassification.from_pretrained("google/dust3r-base-patch16-224").to(torch_device)

        encoded = self.default_image_processor(images=prepare_img(), return_tensors="pt").to(torch_device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = model(**encoded).logits

        # Verify the logits: overall shape, then the first three values.
        self.assertEqual(logits.shape, torch.Size((1, 1000)))
        reference = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device)
        torch.testing.assert_close(logits[0, :3], reference, rtol=1e-4, atol=1e-4)

    @slow
    def test_inference_interpolate_pos_encoding(self):
        """The model runs on 480-sized inputs when `interpolate_pos_encoding=True` is passed."""
        # The forward call below passes `interpolate_pos_encoding=True` so that the
        # pre-trained position embeddings can be used at a resolution higher than
        # the one the checkpoint was trained on.
        checkpoint = "facebook/dino-dust3rs8"
        model = Dust3RModel.from_pretrained(checkpoint).to(torch_device)
        processor = Dust3RImageProcessor.from_pretrained(checkpoint, size=480)

        encoded = processor(images=prepare_img(), return_tensors="pt")
        pixel_values = encoded.pixel_values.to(torch_device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            hidden = model(pixel_values, interpolate_pos_encoding=True).last_hidden_state

        # Verify the hidden states: overall shape, then the top-left 3x3 corner.
        self.assertEqual(hidden.shape, torch.Size((1, 3601, 384)))
        reference = torch.tensor(
            [[4.2325, 4.3882, -6.6678], [4.5372, 1.8933, -6.7355], [4.4454, 0.8514, -5.8747]]
        ).to(torch_device)
        torch.testing.assert_close(hidden[0, :3, :3], reference, rtol=1e-3, atol=1e-3)

    @slow
    @require_accelerate
    @require_torch_accelerator
    @require_torch_fp16
    def test_inference_fp16(self):
        r"""
        Make sure that a forward pass runs in half precision without raising.
        """
        model = Dust3RModel.from_pretrained("facebook/dino-dust3rs8", torch_dtype=torch.float16, device_map="auto")

        encoded = self.default_image_processor(images=prepare_img(), return_tensors="pt")
        pixel_values = encoded.pixel_values.to(torch_device)

        # Only the absence of an exception is checked; the output is discarded.
        with torch.no_grad():
            _ = model(pixel_values)