diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 8e9e559cbbc..a6232787f46 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1624,6 +1624,7 @@ else:
     _import_structure["models.dinov2"].extend(
         [
             "DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Dinov2Backbone",
             "Dinov2ForImageClassification",
             "Dinov2Model",
             "Dinov2PreTrainedModel",
@@ -5540,6 +5541,7 @@ if TYPE_CHECKING:
         )
         from .models.dinov2 import (
             DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST,
+            Dinov2Backbone,
             Dinov2ForImageClassification,
             Dinov2Model,
             Dinov2PreTrainedModel,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 7661f5154f3..fa9a483eb31 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -1056,6 +1056,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
         ("convnext", "ConvNextBackbone"),
         ("convnextv2", "ConvNextV2Backbone"),
         ("dinat", "DinatBackbone"),
+        ("dinov2", "Dinov2Backbone"),
         ("focalnet", "FocalNetBackbone"),
         ("maskformer-swin", "MaskFormerSwinBackbone"),
         ("nat", "NatBackbone"),
diff --git a/src/transformers/models/dinov2/__init__.py b/src/transformers/models/dinov2/__init__.py
index 524e77407bb..01d02a9e65f 100644
--- a/src/transformers/models/dinov2/__init__.py
+++ b/src/transformers/models/dinov2/__init__.py
@@ -35,6 +35,7 @@ else:
         "Dinov2ForImageClassification",
         "Dinov2Model",
         "Dinov2PreTrainedModel",
+        "Dinov2Backbone",
     ]
 
 if TYPE_CHECKING:
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
     else:
         from .modeling_dinov2 import (
             DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST,
+            Dinov2Backbone,
             Dinov2ForImageClassification,
             Dinov2Model,
             Dinov2PreTrainedModel,
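A quick sketch of what the wiring above enables: once `("dinov2", "Dinov2Backbone")` is registered in `MODEL_FOR_BACKBONE_MAPPING_NAMES`, `AutoBackbone` dispatches DINOv2 checkpoints to the new class. This assumes the `facebook/dinov2-base` checkpoint used throughout this diff and downloads its weights:

```python
from transformers import AutoBackbone, Dinov2Backbone

# AutoBackbone resolves model_type "dinov2" through the new mapping entry
model = AutoBackbone.from_pretrained("facebook/dinov2-base")
assert isinstance(model, Dinov2Backbone)
print(model.out_features)  # defaults to the last stage, e.g. ["stage12"]
```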
diff --git a/src/transformers/models/dinov2/configuration_dinov2.py b/src/transformers/models/dinov2/configuration_dinov2.py
index 169d25ec8e9..3981eb457af 100644
--- a/src/transformers/models/dinov2/configuration_dinov2.py
+++ b/src/transformers/models/dinov2/configuration_dinov2.py
@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -31,7 +32,7 @@ DINOV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class Dinov2Config(PretrainedConfig):
+class Dinov2Config(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`Dinov2Model`]. It is used to instantiate an
     Dinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -41,7 +42,6 @@
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
-
     Args:
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers and the pooler layer.
@@ -76,6 +76,20 @@
             Stochastic depth rate per sample (when applied in the main path of residual layers).
         use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
             Whether to use the SwiGLU feedforward neural network.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.
+        apply_layernorm (`bool`, *optional*, defaults to `True`):
+            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
+        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
+            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
+            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
+            seq_len, hidden_size)`.
 
     Example:
 
@@ -111,6 +125,10 @@
         layerscale_value=1.0,
         drop_path_rate=0.0,
         use_swiglu_ffn=False,
+        out_features=None,
+        out_indices=None,
+        apply_layernorm=True,
+        reshape_hidden_states=True,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -131,6 +149,12 @@
         self.layerscale_value = layerscale_value
         self.drop_path_rate = drop_path_rate
         self.use_swiglu_ffn = use_swiglu_ffn
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+        self.apply_layernorm = apply_layernorm
+        self.reshape_hidden_states = reshape_hidden_states
 
 
 class Dinov2OnnxConfig(OnnxConfig):
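To illustrate the `out_features`/`out_indices` alignment documented above, a minimal sketch (the values follow from `stage_names = ["stem", "stage1", ..., "stage12"]` for the 12-layer default config):

```python
from transformers import Dinov2Config

# Requesting indices fills in the matching stage names
config = Dinov2Config(out_indices=[2, 5, 8, 11])
print(config.out_features)  # stage2, stage5, stage8, stage11
print(config.out_indices)   # 2, 5, 8, 11

# Conversely, naming features fills in the matching indices
config = Dinov2Config(out_features=["stage4"])
print(config.out_indices)  # 4
```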
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index 5af80fc4503..8816dbe49c7 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import (
+    BackboneOutput,
     BaseModelOutput,
     BaseModelOutputWithPooling,
     ImageClassifierOutput,
@@ -37,7 +38,9 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_dinov2 import Dinov2Config
 
 
@@ -48,11 +51,10 @@ _CONFIG_FOR_DOC = "Dinov2Config"
 
 # Base docstring
 _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
-_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
 
 # Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base-patch16-224"
-_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
 
 
 DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -111,7 +113,7 @@ class Dinov2Embeddings(nn.Module):
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
 
-    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
         batch_size, _, height, width = pixel_values.shape
         embeddings = self.patch_embeddings(pixel_values)
 
@@ -691,7 +693,6 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
         checkpoint=_IMAGE_CLASS_CHECKPOINT,
         output_type=ImageClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
     )
     def forward(
         self,
@@ -762,3 +763,103 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+@add_start_docstrings(
+    """
+    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
+    """,
+    DINOV2_START_DOCSTRING,
+)
+class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+        self.embeddings = Dinov2Embeddings(config)
+        self.encoder = Dinov2Encoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 768, 16, 16]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        embedding_output = self.embeddings(pixel_values)
+
+        outputs = self.encoder(
+            embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
+        )
+
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                if self.config.apply_layernorm:
+                    hidden_state = self.layernorm(hidden_state)
+                if self.config.reshape_hidden_states:
+                    batch_size, _, height, width = pixel_values.shape
+                    patch_size = self.config.patch_size
+                    hidden_state = hidden_state[:, 1:, :].reshape(
+                        batch_size, height // patch_size, width // patch_size, -1
+                    )
+                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            if output_hidden_states:
+                output = (feature_maps,) + outputs[1:]
+            else:
+                output = (feature_maps,) + outputs[2:]
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions if output_attentions else None,
+        )
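The reshaping logic added above is easiest to see on shapes. A sketch (downloads the base checkpoint; a 224x224 input with patch size 14 gives a 16x16 grid, and the CLS token is dropped before reshaping):

```python
import torch
from transformers import Dinov2Backbone

model = Dinov2Backbone.from_pretrained("facebook/dinov2-base")
pixel_values = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.feature_maps[-1].shape)  # (1, 768, 16, 16): CLS dropped, grid restored

# With reshape_hidden_states=False the maps stay sequence-shaped
model = Dinov2Backbone.from_pretrained("facebook/dinov2-base", reshape_hidden_states=False)
with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.feature_maps[-1].shape)  # (1, 257, 768): 256 patches + CLS token
```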
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 7189861619b..1b262857025 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -2689,6 +2689,13 @@ class DinatPreTrainedModel(metaclass=DummyObject):
 DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
+class Dinov2Backbone(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class Dinov2ForImageClassification(metaclass=DummyObject):
     _backends = ["torch"]
 
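The dummy object follows the usual pattern: in a torch-less environment, `Dinov2Backbone` is still importable from the top-level package, but constructing it raises. A sketch of that behavior:

```python
# Without torch installed, `from transformers import Dinov2Backbone` resolves
# to this dummy class, and construction fails with a helpful message:
from transformers.utils.dummy_pt_objects import Dinov2Backbone

try:
    Dinov2Backbone()
except ImportError as err:
    print(err)  # explains that Dinov2Backbone requires the PyTorch backend
```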
diff --git a/tests/models/dinov2/test_modeling_dinov2.py b/tests/models/dinov2/test_modeling_dinov2.py
index cf7ff95b572..fa20833823e 100644
--- a/tests/models/dinov2/test_modeling_dinov2.py
+++ b/tests/models/dinov2/test_modeling_dinov2.py
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
 )
 from transformers.utils import cached_property, is_torch_available, is_vision_available
 
+from ...test_backbone_common import BackboneTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 from ...test_pipeline_mixin import PipelineTesterMixin
@@ -36,7 +37,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import Dinov2ForImageClassification, Dinov2Model
+    from transformers import Dinov2Backbone, Dinov2ForImageClassification, Dinov2Model
     from transformers.models.dinov2.modeling_dinov2 import DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST
 
 
@@ -123,6 +124,53 @@ class Dinov2ModelTester:
         result = model(pixel_values)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
 
+    def create_and_check_backbone(self, config, pixel_values, labels):
+        model = Dinov2Backbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+        expected_size = self.image_size // config.patch_size
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
+        )
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), len(config.out_features))
+
+        # verify backbone works with out_features=None
+        config.out_features = None
+        model = Dinov2Backbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
+        )
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), 1)
+
+        # verify backbone works with apply_layernorm=False and reshape_hidden_states=False
+        config.apply_layernorm = False
+        config.reshape_hidden_states = False
+
+        model = Dinov2Backbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape), [self.batch_size, self.seq_length, self.hidden_size]
+        )
+
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
         model = Dinov2ForImageClassification(config)
@@ -159,7 +207,15 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     attention_mask and seq_length.
     """
 
-    all_model_classes = (Dinov2Model, Dinov2ForImageClassification) if is_torch_available() else ()
+    all_model_classes = (
+        (
+            Dinov2Model,
+            Dinov2ForImageClassification,
+            Dinov2Backbone,
+        )
+        if is_torch_available()
+        else ()
+    )
     pipeline_model_mapping = (
         {"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
         if is_torch_available()
@@ -207,10 +263,18 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_backbone(*config_and_inputs)
+
     def test_for_image_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
 
+    @unittest.skip(reason="Dinov2 does not support feedforward chunking yet")
+    def test_feed_forward_chunking(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -252,3 +316,14 @@ class Dinov2ModelIntegrationTest(unittest.TestCase):
             device=torch_device,
         )
         self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+
+@require_torch
+class Dinov2BackboneTest(BackboneTesterMixin, unittest.TestCase):
+    all_model_classes = (Dinov2Backbone,) if is_torch_available() else ()
+    config_class = Dinov2Config
+
+    has_attentions = False
+
+    def setUp(self):
+        self.model_tester = Dinov2ModelTester(self)
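The new `create_and_check_backbone` checks can be reproduced by hand with a tiny random-weight config (a sketch; the small hyperparameters are arbitrary, chosen only so the shapes are easy to follow):

```python
import torch
from transformers import Dinov2Backbone, Dinov2Config

config = Dinov2Config(
    image_size=32, patch_size=4, num_channels=3,
    hidden_size=32, num_hidden_layers=4, num_attention_heads=2,
    out_indices=[2],
)
model = Dinov2Backbone(config).eval()

with torch.no_grad():
    outputs = model(torch.rand(1, 3, 32, 32))

# One requested stage -> one feature map, one channel entry
assert len(outputs.feature_maps) == len(model.channels) == 1
# 32 / 4 = 8 patches per side; the channel dim equals hidden_size
assert list(outputs.feature_maps[0].shape) == [1, 32, 8, 8]
```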
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 9da984c2427..c46b82b7c67 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -956,6 +956,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
     "ConvNextBackbone",
     "ConvNextV2Backbone",
     "DinatBackbone",
+    "Dinov2Backbone",
     "FocalNetBackbone",
     "MaskFormerSwinBackbone",
     "MaskFormerSwinConfig",
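Taken together, these changes make DINOv2 usable wherever a backbone is expected. One closing sketch: since this is an isotropic ViT-style backbone, every requested stage reports the same channel width, which downstream necks (e.g. DETR- or MaskFormer-style frameworks) read from `model.channels`:

```python
from transformers import AutoBackbone

model = AutoBackbone.from_pretrained(
    "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
)
# ViT backbones are isotropic: one width for all stages
print(model.channels)  # [768, 768, 768, 768]
```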