Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
Cleanup Attention class for Siglip and dependent models (#39040)
* cleanup attention class
* More models
* more models
* Changes
* make style
* Should fix CI
* This should work 🙏
This commit is contained in:
parent 1ccc73dee9
commit 0d66ef7792
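All of the hunks below touch the same attention dispatch pattern: an `eager_attention_forward` helper plus a lookup into `ALL_ATTENTION_FUNCTIONS` keyed by `config._attn_implementation`. The sketch below is a minimal, self-contained illustration of that pattern only — the registry dict, the `pick_attention_interface` helper, and the SDPA wrapper are hypothetical stand-ins for this illustration, not the code changed by this commit.

```python
# Minimal sketch of the attention-interface dispatch seen in the hunks below.
# Illustrative only: the registry and helper names here are hypothetical stand-ins,
# not the actual transformers implementation touched by this commit.
from typing import Callable, Optional

import torch
from torch import nn


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,  # (batch, num_heads, seq_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Plain softmax(QK^T * scaling) V, ending with the matmul/transpose lines shown in the diff.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def sdpa_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Fused kernel path; note it cannot return per-head attention weights.
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout if module.training else 0.0, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Hypothetical registry standing in for ALL_ATTENTION_FUNCTIONS.
ATTENTION_FUNCTIONS: dict[str, Callable] = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}


def pick_attention_interface(attn_implementation: str) -> Callable:
    # The dispatch the attention classes rely on: eager by default, otherwise the registry entry.
    if attn_implementation == "eager":
        return eager_attention_forward
    return ATTENTION_FUNCTIONS[attn_implementation]
```

In the real modules, the forward pass calls the selected interface with `query`, `key`, and `value` shaped `(batch, num_heads, seq_len, head_dim)` and then reshapes the output back to `(batch, seq_len, embed_dim)` before the output projection, as the hunks below show.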
@@ -623,6 +623,7 @@ def eager_attention_forward(

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

@@ -275,6 +275,7 @@ def eager_attention_forward(

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
@@ -606,7 +606,7 @@ class Emu3VQVAEAttentionBlock(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

@@ -622,12 +622,6 @@ class Emu3VQVAEAttentionBlock(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(

@@ -644,9 +638,6 @@ class Emu3VQVAEAttentionBlock(nn.Module):
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
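As a usage note tied to the warning text in the hunk above: the attention implementation can be pinned when loading a model via the `attn_implementation` argument of `from_pretrained`. The checkpoint id below is only an example and is not taken from this commit.

```python
# Hypothetical usage: force eager attention at load time, as the warning text above suggests.
# The checkpoint id is an illustrative example, not part of this commit.
from transformers import AutoModel

model = AutoModel.from_pretrained("google/siglip-base-patch16-224", attn_implementation="eager")
```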
@@ -620,6 +620,7 @@ def eager_attention_forward(

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

@@ -185,6 +185,7 @@ def eager_attention_forward(

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
@@ -219,7 +219,7 @@ class Idefics2VisionAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

@@ -235,12 +235,6 @@ class Idefics2VisionAttention(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(

@@ -257,9 +251,6 @@ class Idefics2VisionAttention(nn.Module):
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
@@ -216,7 +216,7 @@ class Idefics3VisionAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

@@ -232,12 +232,6 @@ class Idefics3VisionAttention(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(

@@ -254,9 +248,6 @@ class Idefics3VisionAttention(nn.Module):
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
@@ -21,7 +21,6 @@ from typing import Any, Callable, Optional, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out

@@ -31,13 +30,10 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, torch_int
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig


logger = logging.get_logger(__name__)


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

@@ -372,7 +368,7 @@ class SiglipAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

@@ -388,12 +384,6 @@ class SiglipAttention(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(

@@ -410,9 +400,6 @@ class SiglipAttention(nn.Module):
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
@@ -35,13 +35,10 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""

@@ -266,7 +263,7 @@ class Siglip2Attention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

@@ -282,12 +279,6 @@ class Siglip2Attention(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(

@@ -304,9 +295,6 @@ class Siglip2Attention(nn.Module):
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
@@ -186,7 +186,7 @@ class SmolVLMVisionAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

@@ -202,12 +202,6 @@ class SmolVLMVisionAttention(nn.Module):

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(

@@ -224,9 +218,6 @@ class SmolVLMVisionAttention(nn.Module):
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
@@ -1008,8 +1008,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        **flash_attn_kwargs: flash attention related parameters.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache

@@ -1084,10 +1082,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        r"""
        **flash_attn_kwargs: flash attention related parameters.
        """

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,

@@ -1162,7 +1156,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1234,7 +1227,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
@auto_docstring
class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
        """
        r"""
        is_encoder_decoder (`Optional`, *optional*):
            Whether use encoder_decoder for sequence classification. When set to False, only encoder is used.
        """

@@ -1286,7 +1279,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

@@ -1382,7 +1374,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
@auto_docstring
class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
        """
        r"""
        is_encoder_decoder (`Optional`, *optional*):
            Whether use encoder_decoder for token classification. When set to False, only encoder is used.
        """

@@ -1435,7 +1427,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
@@ -955,8 +955,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        **flash_attn_kwargs: flash attention related parameters.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache

@@ -1031,10 +1029,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutput:
        r"""
        **flash_attn_kwargs: flash attention related parameters.
        """

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,

@@ -1109,7 +1103,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1181,7 +1174,7 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
@auto_docstring
class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
        """
        r"""
        is_encoder_decoder (`Optional`, *optional*):
            Whether use encoder_decoder for sequence classification. When set to False, only encoder is used.
        """

@@ -1233,7 +1226,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If

@@ -1329,7 +1321,7 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
@auto_docstring
class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
        """
        r"""
        is_encoder_decoder (`Optional`, *optional*):
            Whether use encoder_decoder for token classification. When set to False, only encoder is used.
        """

@@ -1382,7 +1374,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
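The T5Gemma docstring hunks above describe `decoder_position_ids` shaped `(batch_size, decoder_sequence_length)` and `labels` in which `-100` entries are ignored by the loss. A tiny sketch of those shapes follows; the values are made up for illustration and are not taken from this commit.

```python
# Illustrative tensors matching the docstring shapes above; values are arbitrary.
import torch

batch_size, decoder_sequence_length = 2, 5
decoder_position_ids = torch.arange(decoder_sequence_length).unsqueeze(0).expand(batch_size, -1)
labels = torch.full((batch_size, decoder_sequence_length), -100)  # -100 positions are ignored by the loss
labels[:, -1] = 0  # only the final position would contribute to the loss in this toy setup
```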
@@ -240,6 +240,7 @@ def eager_attention_forward(

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights