Cleanup Attention class for Siglip and dependent models (#39040)

* cleanup attention class

* More models

* more models

* Changes

* make style

* Should fix CI

* This should work 🙏
Authored by Yaswanth Gali on 2025-06-27 15:44:09 +05:30; committed by GitHub
parent 1ccc73dee9
commit 0d66ef7792
13 changed files with 23 additions and 97 deletions
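To make the per-file hunks below easier to follow, here is a minimal, self-contained sketch of the forward pattern the cleanup converges on: `output_attentions` disappears from the attention signatures, the SDPA-fallback warning is dropped, and the kernel is taken straight from the attention-function registry. Class and helper names are simplified stand-ins for illustration, not the exact contents of any changed file.

# Hedged sketch of the post-cleanup attention forward. "VisionAttention",
# "ATTENTION_FUNCTIONS" and the constructor arguments are simplified stand-ins
# for the real classes, ALL_ATTENTION_FUNCTIONS and config._attn_implementation.
from typing import Callable, Optional

import torch
from torch import nn


def eager_attention_forward(module, query, key, value, attention_mask, dropout=0.0, scaling=None, **kwargs):
    # query/key/value: (batch, heads, seq, head_dim)
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = attn_weights.softmax(dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def sdpa_attention_forward(module, query, key, value, attention_mask, dropout=0.0, scaling=None, is_causal=False, **kwargs):
    # The scale= keyword requires PyTorch >= 2.1.
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling, is_causal=is_causal
    )
    return attn_output.transpose(1, 2).contiguous(), None  # SDPA exposes no attention weights


# Stand-in for the ALL_ATTENTION_FUNCTIONS registry used in the diffs.
ATTENTION_FUNCTIONS: dict[str, Callable] = {"sdpa": sdpa_attention_forward}


class VisionAttention(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4, attn_implementation="eager"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.dropout = 0.0
        self.attn_implementation = attn_implementation  # stands in for config._attn_implementation
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):
        # `output_attentions` is no longer part of the signature; extras flow through **kwargs.
        batch_size, seq_length, embed_dim = hidden_states.shape
        queries = self.q_proj(hidden_states).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(hidden_states).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # The SDPA/output_attentions fallback warning is gone: dispatch directly.
        attention_interface: Callable = eager_attention_forward
        if self.attn_implementation != "eager":
            attention_interface = ATTENTION_FUNCTIONS[self.attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scale,
            is_causal=False,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        return self.out_proj(attn_output), attn_weights


if __name__ == "__main__":
    x = torch.randn(2, 16, 64)
    out, weights = VisionAttention(attn_implementation="eager")(x)
    print(out.shape, weights.shape)  # weights come back from the eager kernel
    out, weights = VisionAttention(attn_implementation="sdpa")(x)
    print(out.shape, weights)  # None: SDPA does not return weights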

View File

@@ -623,6 +623,7 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights

View File

@@ -275,6 +275,7 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights

View File

@@ -606,7 +606,7 @@ class Emu3VQVAEAttentionBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -622,13 +622,7 @@ class Emu3VQVAEAttentionBlock(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
             self,
@@ -644,9 +638,6 @@ class Emu3VQVAEAttentionBlock(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
         attn_output = self.out_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights

View File

@@ -620,6 +620,7 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights

View File

@@ -185,6 +185,7 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights

View File

@@ -219,7 +219,7 @@ class Idefics2VisionAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -235,13 +235,7 @@ class Idefics2VisionAttention(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
             self,
@@ -257,9 +251,6 @@ class Idefics2VisionAttention(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
         attn_output = self.out_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights

View File

@@ -216,7 +216,7 @@ class Idefics3VisionAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -232,13 +232,7 @@ class Idefics3VisionAttention(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
             self,
@@ -254,9 +248,6 @@ class Idefics3VisionAttention(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
         attn_output = self.out_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights

View File

@@ -21,7 +21,6 @@ from typing import Any, Callable, Optional, Union
 import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.init import _calculate_fan_in_and_fan_out
@@ -31,13 +30,10 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
+from ...utils import ModelOutput, auto_docstring, can_return_tuple, torch_int
 from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
-logger = logging.get_logger(__name__)
 def _trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -372,7 +368,7 @@ class SiglipAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -388,13 +384,7 @@ class SiglipAttention(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
             self,
@@ -410,9 +400,6 @@ class SiglipAttention(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
         attn_output = self.out_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights

View File

@@ -35,13 +35,10 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
+from ...utils import ModelOutput, auto_docstring, can_return_tuple
 from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
-logger = logging.get_logger(__name__)
 @dataclass
 @auto_docstring(
     custom_intro="""
@@ -266,7 +263,7 @@ class Siglip2Attention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -282,13 +279,7 @@ class Siglip2Attention(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
             self,
@@ -304,9 +295,6 @@ class Siglip2Attention(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
         attn_output = self.out_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights

View File

@@ -186,7 +186,7 @@ class SmolVLMVisionAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""
@@ -202,13 +202,7 @@ class SmolVLMVisionAttention(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and output_attentions:
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
         attn_output, attn_weights = attention_interface(
             self,
@@ -224,9 +218,6 @@ class SmolVLMVisionAttention(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
         attn_output = self.out_proj(attn_output)
-        if not output_attentions:
-            attn_weights = None
         return attn_output, attn_weights
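A practical note implied by the removed warning text (an observation, not part of the diff): only the eager path materializes attention weights, so callers that rely on `output_attentions=True` should load the model with `attn_implementation="eager"`. A hedged usage sketch, with an illustrative checkpoint name:

# Usage sketch (assumes weights can be downloaded); the checkpoint name is
# illustrative -- any Siglip-family vision model follows the same path.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "google/siglip-base-patch16-224",
    attn_implementation="eager",  # SDPA/flash kernels return attn_weights=None
)

pixel_values = torch.randn(1, 3, 224, 224)  # random tensor, just to show shapes
outputs = model.vision_model(pixel_values=pixel_values, output_attentions=True)
print(len(outputs.attentions), outputs.attentions[0].shape)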

View File

@@ -1008,8 +1008,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-        **flash_attn_kwargs: flash attention related parameters.
         """
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -1084,10 +1082,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutput:
-        r"""
-        **flash_attn_kwargs: flash attention related parameters.
-        """
         encoder_outputs = self.encoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1162,7 +1156,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -1234,7 +1227,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
 @auto_docstring
 class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
     def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
-        """
+        r"""
         is_encoder_decoder (`Optional`, *optional*):
             Whether use encoder_decoder for sequence classification. When set to False, only encoder is used.
         """
@@ -1286,7 +1279,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@@ -1382,7 +1374,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
 @auto_docstring
 class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
     def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
-        """
+        r"""
         is_encoder_decoder (`Optional`, *optional*):
             Whether use encoder_decoder for token classification. When set to False, only encoder is used.
         """
@@ -1435,7 +1427,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
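Aside on the `"""` to `r"""` change in the two `__init__` docstrings above: a raw docstring keeps backslash sequences in the documented text literal instead of letting Python interpret them, which is presumably why these parsed docstrings are normalized to `r"""`. A small illustration with hypothetical function names:

# Raw vs. plain docstrings: backslash escapes behave differently.
def plain():
    """Interval is written as `[0, \t1]`."""

def raw():
    r"""Interval is written as `[0, \t1]`."""

print(plain.__doc__)  # the \t is interpreted as a tab character
print(raw.__doc__)    # the \t stays as a literal backslash followed by "t"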

View File

@@ -955,8 +955,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-        **flash_attn_kwargs: flash attention related parameters.
         """
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -1031,10 +1029,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutput:
-        r"""
-        **flash_attn_kwargs: flash attention related parameters.
-        """
         encoder_outputs = self.encoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -1109,7 +1103,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -1181,7 +1174,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
 @auto_docstring
 class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
     def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
-        """
+        r"""
         is_encoder_decoder (`Optional`, *optional*):
             Whether use encoder_decoder for sequence classification. When set to False, only encoder is used.
         """
@@ -1233,7 +1226,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@@ -1329,7 +1321,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
 @auto_docstring
 class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
     def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
-        """
+        r"""
         is_encoder_decoder (`Optional`, *optional*):
             Whether use encoder_decoder for token classification. When set to False, only encoder is used.
         """
@@ -1382,7 +1374,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
             Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
             config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

View File

@@ -240,6 +240,7 @@ def eager_attention_forward(
     attn_output = torch.matmul(attn_weights, value)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights