Fix DPT /Dinov2 sdpa regression on main (#33660)

* fallback to eager if output attentions.

* fix copies
Pablo Montalvo 2024-09-23 11:49:16 +02:00 committed by GitHub
parent 9eb93854b9
commit b7c381f011


@@ -231,7 +231,6 @@ class Dinov2SelfAttention(nn.Module):
         return outputs

-# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Dinov2
 class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
     def __init__(self, config: Dinov2Config) -> None:
         super().__init__(config)
@@ -240,6 +239,16 @@ class Dinov2SdpaSelfAttention(Dinov2SelfAttention):
     def forward(
         self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Dinov2Model is using Dinov2SdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states, head_mask=head_mask, output_attentions=output_attentions
+            )
+
         mixed_query_layer = self.query(hidden_states)

         key_layer = self.transpose_for_scores(self.key(hidden_states))
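
For context, the warning added above fires when the default SDPA attention is asked for attention maps and silently falls back to the eager (manual) path. A minimal usage sketch of the workaround the warning points to, assuming the public `facebook/dinov2-base` checkpoint and random pixel values in place of a real image:

```python
import torch
from transformers import AutoModel

# Load Dinov2 with the eager attention implementation so that attention
# weights can be returned without triggering the SDPA fallback warning.
model = AutoModel.from_pretrained("facebook/dinov2-base", attn_implementation="eager")

# Stand-in input: a single 224x224 RGB image tensor.
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values, output_attentions=True)

# One attention tensor per layer, shape (batch, heads, seq_len, seq_len).
print(len(outputs.attentions), outputs.attentions[0].shape)
```

Passing `attn_implementation="eager"` up front makes the attention path explicit; without it, the fallback in this commit keeps `output_attentions=True` working but logs the warning once per process.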