[DINOv2] Add backbone class (#25520)

* First draft

* More improvements

* Fix all tests

* More improvements

* Add backbone test

* Improve docstring

* Address comments

* Rename attribute

* Remove expected output

* Update src/transformers/models/dinov2/modeling_dinov2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix style

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
NielsRogge 2023-08-29 12:05:27 +02:00 committed by GitHub
parent 4c21da5e34
commit 77713d11f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 222 additions and 9 deletions

View File

@ -1624,6 +1624,7 @@ else:
_import_structure["models.dinov2"].extend(
[
"DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Dinov2Backbone",
"Dinov2ForImageClassification",
"Dinov2Model",
"Dinov2PreTrainedModel",
@ -5540,6 +5541,7 @@ if TYPE_CHECKING:
)
from .models.dinov2 import (
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST,
Dinov2Backbone,
Dinov2ForImageClassification,
Dinov2Model,
Dinov2PreTrainedModel,

View File

@ -1056,6 +1056,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextBackbone"),
("convnextv2", "ConvNextV2Backbone"),
("dinat", "DinatBackbone"),
("dinov2", "Dinov2Backbone"),
("focalnet", "FocalNetBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),

View File

@ -35,6 +35,7 @@ else:
"Dinov2ForImageClassification",
"Dinov2Model",
"Dinov2PreTrainedModel",
"Dinov2Backbone",
]
if TYPE_CHECKING:
@ -48,6 +49,7 @@ if TYPE_CHECKING:
else:
from .modeling_dinov2 import (
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST,
Dinov2Backbone,
Dinov2ForImageClassification,
Dinov2Model,
Dinov2PreTrainedModel,

View File

@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
@ -31,7 +32,7 @@ DINOV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class Dinov2Config(PretrainedConfig):
class Dinov2Config(PretrainedConfig, BackboneConfigMixin):
r"""
This is the configuration class to store the configuration of a [`Dinov2Model`]. It is used to instantiate an
Dinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
@ -41,7 +42,6 @@ class Dinov2Config(PretrainedConfig):
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
@ -76,6 +76,20 @@ class Dinov2Config(PretrainedConfig):
Stochastic depth rate per sample (when applied in the main path of residual layers).
use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
Whether to use the SwiGLU feedforward neural network.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage.
apply_layernorm (`bool`, *optional*, defaults to `True`):
Whether to apply layer normalization to the feature maps in case the model is used as backbone.
reshape_hidden_states (`bool`, *optional*, defaults to `True`):
Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
seq_len, hidden_size)`.
Example:
@ -111,6 +125,10 @@ class Dinov2Config(PretrainedConfig):
layerscale_value=1.0,
drop_path_rate=0.0,
use_swiglu_ffn=False,
out_features=None,
out_indices=None,
apply_layernorm=True,
reshape_hidden_states=True,
**kwargs,
):
super().__init__(**kwargs)
@ -131,6 +149,12 @@ class Dinov2Config(PretrainedConfig):
self.layerscale_value = layerscale_value
self.drop_path_rate = drop_path_rate
self.use_swiglu_ffn = use_swiglu_ffn
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
self.apply_layernorm = apply_layernorm
self.reshape_hidden_states = reshape_hidden_states
class Dinov2OnnxConfig(OnnxConfig):

View File

@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
@ -37,7 +38,9 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_dinov2 import Dinov2Config
@ -48,11 +51,10 @@ _CONFIG_FOR_DOC = "Dinov2Config"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -111,7 +113,7 @@ class Dinov2Embeddings(nn.Module):
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor) -> torch.Tensor:
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
@ -691,7 +693,6 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
@ -762,3 +763,103 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
""",
DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
self.embeddings = Dinov2Embeddings(config)
self.encoder = Dinov2Encoder(config)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
return self.embeddings.patch_embeddings
@add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
>>> model = AutoBackbone.from_pretrained(
... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
... )
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 768, 16, 16]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
embedding_output = self.embeddings(pixel_values)
outputs = self.encoder(
embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
)
hidden_states = outputs.hidden_states if return_dict else outputs[1]
feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
if self.config.apply_layernorm:
hidden_state = self.layernorm(hidden_state)
if self.config.reshape_hidden_states:
batch_size, _, height, width = pixel_values.shape
patch_size = self.config.patch_size
hidden_state = hidden_state[:, 1:, :].reshape(
batch_size, width // patch_size, height // patch_size, -1
)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_maps += (hidden_state,)
if not return_dict:
if output_hidden_states:
output = (feature_maps,) + outputs[1:]
else:
output = (feature_maps,) + outputs[2:]
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions if output_attentions else None,
)

View File

@ -2689,6 +2689,13 @@ class DinatPreTrainedModel(metaclass=DummyObject):
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Dinov2Backbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Dinov2ForImageClassification(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -27,6 +27,7 @@ from transformers.testing_utils import (
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@ -36,7 +37,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import Dinov2ForImageClassification, Dinov2Model
from transformers import Dinov2Backbone, Dinov2ForImageClassification, Dinov2Model
from transformers.models.dinov2.modeling_dinov2 import DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST
@ -123,6 +124,53 @@ class Dinov2ModelTester:
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_backbone(self, config, pixel_values, labels):
model = Dinov2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
expected_size = self.image_size // config.patch_size
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
)
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = Dinov2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], expected_size, expected_size]
)
# verify channels
self.parent.assertEqual(len(model.channels), 1)
# verify backbone works with apply_layernorm=False and reshape_hidden_states=False
config.apply_layernorm = False
config.reshape_hidden_states = False
model = Dinov2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = Dinov2ForImageClassification(config)
@ -159,7 +207,15 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
all_model_classes = (Dinov2Model, Dinov2ForImageClassification) if is_torch_available() else ()
all_model_classes = (
(
Dinov2Model,
Dinov2ForImageClassification,
Dinov2Backbone,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
if is_torch_available()
@ -207,10 +263,18 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@unittest.skip(reason="Dinov2 does not support feedforward chunking yet")
def test_feed_forward_chunking(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@ -252,3 +316,14 @@ class Dinov2ModelIntegrationTest(unittest.TestCase):
device=torch_device,
)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@require_torch
class Dinov2BackboneTest(unittest.TestCase, BackboneTesterMixin):
all_model_classes = (Dinov2Backbone,) if is_torch_available() else ()
config_class = Dinov2Config
has_attentions = False
def setUp(self):
self.model_tester = Dinov2ModelTester(self)

View File

@ -956,6 +956,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"ConvNextBackbone",
"ConvNextV2Backbone",
"DinatBackbone",
"Dinov2Backbone",
"FocalNetBackbone",
"MaskFormerSwinBackbone",
"MaskFormerSwinConfig",