Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Sdpa dino v2 (#33403)
* add sdpa to dinov2
* fixup
* add dinov2 to sdpa doc
* update doc order
* [run-slow] dinov2
* common to eager
* [run-slow] dinov2
* update attn implementation in common
* update test_modeling_dinov2 to have mask_ratio, num_masks and mask_length similar to vit
* [run-slow] dinov2

Co-authored-by: Avishai Elmakies <avishai.elma@cs.huji.ac.il>
This commit is contained in:
parent e71bf70e33
commit 78b2929c05
docs/source/en/perf_infer_gpu_one.md

@@ -217,6 +217,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
 * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
 * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
+* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
 * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
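With DINOv2 on the supported list, SDPA can be requested explicitly at load time. A minimal sketch, assuming a transformers version that contains this commit and using the `facebook/dinov2-base` checkpoint as an example:

```python
import torch
from transformers import AutoModel

# Load DINOv2 with the SDPA attention implementation.
model = AutoModel.from_pretrained("facebook/dinov2-base", attn_implementation="sdpa").eval()

# Dummy batch of images: (batch, channels, height, width).
pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 257, 768) for dinov2-base at 224x224 (patch size 14)
```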
@@ -275,7 +276,6 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
 * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
 
 <Tip>
 
 FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
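Per the tip above, dispatching to the FlashAttention kernel additionally requires a half-precision model. A sketch of the same call with an explicit dtype (a CUDA device is assumed; `fp32` models can still use the memory-efficient backend):

```python
import torch
from transformers import AutoModel

# Loading in fp16/bf16 lets SDPA pick the FlashAttention backend on supported GPUs.
model = AutoModel.from_pretrained(
    "facebook/dinov2-base",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")
```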
src/transformers/models/dinov2/modeling_dinov2.py

@@ -231,6 +231,38 @@ class Dinov2SelfAttention(nn.Module):
         return outputs
 
 
+# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2
+class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
+    def __init__(self, config: Dinov2Config) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
+
+
 # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
 class Dinov2SelfOutput(nn.Module):
     """
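The new class keeps the eager module's query/key/value projections and only swaps the score computation for `torch.nn.functional.scaled_dot_product_attention`. A standalone sketch (plain PyTorch, no transformers imports) of the equivalence it relies on:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 16, 8
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

# Eager attention: softmax(Q K^T / sqrt(d)) V
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
eager_out = F.softmax(scores, dim=-1) @ v

# SDPA path, as used by Dinov2SdpaSelfAttention (no mask, no dropout, non-causal).
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

print(torch.allclose(eager_out, sdpa_out, atol=1e-6))  # expected: True
```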
@@ -290,6 +322,13 @@ class Dinov2Attention(nn.Module):
         return outputs
 
 
+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Dinov2
+class Dinov2SdpaAttention(Dinov2Attention):
+    def __init__(self, config: Dinov2Config) -> None:
+        super().__init__(config)
+        self.attention = Dinov2SdpaSelfAttention(config)
+
+
 class Dinov2LayerScale(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
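A quick shape check of the wrapper, assuming a transformers version that contains this commit; the tiny config values are arbitrary and only for illustration:

```python
import torch
from transformers import Dinov2Config
from transformers.models.dinov2.modeling_dinov2 import Dinov2SdpaAttention

config = Dinov2Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4)
attn = Dinov2SdpaAttention(config).eval()

hidden_states = torch.randn(2, 257, 64)  # (batch, 1 + num_patches, hidden_size)
with torch.no_grad():
    attention_output = attn(hidden_states)[0]
print(attention_output.shape)  # torch.Size([2, 257, 64])
```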
@@ -371,6 +410,12 @@ class Dinov2SwiGLUFFN(nn.Module):
         return self.weights_out(hidden)
 
 
+DINOV2_ATTENTION_CLASSES = {
+    "eager": Dinov2Attention,
+    "sdpa": Dinov2SdpaAttention,
+}
+
+
 class Dinov2Layer(nn.Module):
     """This corresponds to the Block class in the original implementation."""
 
@@ -378,7 +423,7 @@
         super().__init__()
 
         self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = Dinov2Attention(config)
+        self.attention = DINOV2_ATTENTION_CLASSES[config._attn_implementation](config)
         self.layer_scale1 = Dinov2LayerScale(config)
         self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
 
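With the dispatch in place, the attention class used by each layer follows `config._attn_implementation`. A small sketch, again assuming a transformers version with this commit (tiny config values are arbitrary; "sdpa" requires a recent PyTorch):

```python
from transformers import Dinov2Config, Dinov2Model

def layer_attention_class(impl: str) -> str:
    # attn_implementation is forwarded to the config, which Dinov2Layer reads.
    config = Dinov2Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, attn_implementation=impl)
    model = Dinov2Model(config)
    return type(model.encoder.layer[0].attention).__name__

print(layer_attention_class("eager"))  # expected: Dinov2Attention
print(layer_attention_class("sdpa"))   # expected: Dinov2SdpaAttention
```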
@@ -485,6 +530,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = ["Dinov2SwiGLUFFN"]
+    _supports_sdpa = True
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
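Setting `_supports_sdpa = True` lets `from_pretrained` select SDPA by default when it is available, without the user asking for it. A hedged check (checkpoint name is just an example):

```python
from transformers import AutoModel

# No attn_implementation requested: with _supports_sdpa = True and a recent PyTorch
# (>= 2.1.1), the resolved implementation should be "sdpa"; otherwise it falls back to "eager".
model = AutoModel.from_pretrained("facebook/dinov2-base")
print(model.config._attn_implementation)
```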
tests/models/dinov2/test_modeling_dinov2.py

@@ -65,6 +65,8 @@ class Dinov2ModelTester:
         type_sequence_label_size=10,
         initializer_range=0.02,
         scope=None,
+        attn_implementation="eager",
+        mask_ratio=0.5,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -83,10 +85,14 @@ class Dinov2ModelTester:
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.attn_implementation = attn_implementation
+        self.mask_ratio = mask_ratio
 
         # in Dinov2, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
         self.seq_length = num_patches + 1
+        self.num_masks = int(self.mask_ratio * self.seq_length)
+        self.mask_length = num_patches
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
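The new `num_masks` and `mask_length` attributes mirror the ViT tester so that the shared SDPA tests can build a `bool_masked_pos` input for DINOv2's masked-image-modeling head. A hedged sketch of how such a mask can be derived from these values (the concrete numbers are illustrative, not the tester's defaults):

```python
import torch

# Illustrative values; in the tester these come from mask_ratio, seq_length and num_patches.
batch_size = 2
seq_length = 197           # 1 + num_patches
mask_length = 196          # num_patches
num_masks = int(0.5 * seq_length)

# One mask per image: the first num_masks patch positions are masked, the rest are not,
# similar to the "ones then zeros" pattern used by the common eager-vs-SDPA test.
mask = torch.cat([torch.ones(num_masks), torch.zeros(mask_length - num_masks)])
bool_masked_pos = mask.expand(batch_size, -1).bool()
print(bool_masked_pos.shape)  # torch.Size([2, 196])
```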
@@ -113,6 +119,7 @@ class Dinov2ModelTester:
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            attn_implementation=self.attn_implementation,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
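With the tester defaulting to "eager" and the slow `[run-slow] dinov2` jobs exercising SDPA, a hand-rolled version of the eager-vs-SDPA comparison looks roughly like this. This is only a sketch: the real check lives in the shared `test_eager_matches_sdpa_inference` test, and the checkpoint and tolerance here are illustrative.

```python
import torch
from transformers import AutoModel

pixel_values = torch.randn(1, 3, 224, 224)

outputs = {}
for impl in ("eager", "sdpa"):
    model = AutoModel.from_pretrained("facebook/dinov2-base", attn_implementation=impl).eval()
    with torch.no_grad():
        outputs[impl] = model(pixel_values=pixel_values).last_hidden_state

# The two implementations should agree up to numerical noise.
print(torch.allclose(outputs["eager"], outputs["sdpa"], atol=1e-4))
```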