mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
358 lines
16 KiB
Python
358 lines
16 KiB
Python
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# This file was automatically generated from examples/modular-transformers/modular_from_uppercase_model.py.
|
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
|
# the file from the modular. If any change should be done, please apply the change to the
|
|
# modular_from_uppercase_model.py file directly. One of our CI enforces this.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
|
|
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
|
|
from .configuration_from_uppercase_model import FromUppercaseModelConfig
|
|
|
|
|
|
if is_flash_attn_2_available():
|
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
|
|
|
|
|
# Module-level logger; this file uses it to emit one-time warnings via `logger.warning_once`.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class FromUppercaseModelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (pure PyTorch) implementation: `hidden_states` is projected to queries, keys and values, split into
    `config.num_attention_heads` heads, combined with scaled dot-product attention and projected back with
    `out_proj`.

    Args:
        config: model configuration; this module reads `hidden_size`, `num_attention_heads` and
            `attention_dropout` from it.

    Raises:
        ValueError: if `hidden_size` is not divisible by `num_attention_heads`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, embed_dim)` into a contiguous `(bsz, num_heads, seq_len, head_dim)` tensor."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: input of shape `(bsz, tgt_len, embed_dim)`.
            attention_mask: optional additive mask of shape `(bsz, 1, tgt_len, src_len)`, added to the
                attention scores before the softmax.
            causal_attention_mask: optional additive mask of the same shape, applied before `attention_mask`.
            output_attentions: if true, also return the softmax attention weights.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` has shape `(bsz, tgt_len, embed_dim)`; `attn_weights`
            has shape `(bsz, num_heads, tgt_len, src_len)` (taken before dropout) or is `None` when
            `output_attentions` is false.

        Raises:
            ValueError: if a mask or an intermediate tensor does not have the expected shape.
        """

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj (pre-scaled so the QK^T product is already divided by sqrt(head_dim))
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold the head dimension into the batch dimension so torch.bmm can be used.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # Report the 3-D shape that is actually being checked.
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Unfold heads back out of the batch dimension and merge them into the channel dimension.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
|
|
|
|
|
|
class FromUppercaseModelFlashAttention2(FromUppercaseModelAttention):
    """
    FromUppercaseModelAttention flash attention module. This module inherits from `FromUppercaseModelAttention` as the weights of the module stay
    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention through `_flash_attention_forward`.

        Args:
            hidden_states: input of shape `(batch_size, q_len, embed_dim)`.
            attention_mask: passed through unchanged to `_flash_attention_forward`.
            causal_attention_mask: only its presence is used; when not `None` the kernel runs with
                `is_causal=True`.
            output_attentions: accepted for interface compatibility but ignored.

        Returns:
            `(attn_output, None)`: `attn_output` has shape `(batch_size, q_len, embed_dim)`. Attention
            weights are never returned by this implementation.
        """
        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention expects inputs of shape (batch_size, seq_len, num_heads, head_dim),
        # so split the channels into heads without transposing.
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to"
                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=causal_attention_mask is not None,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        # Flash attention never materializes the attention probabilities.
        return attn_output, None
|
|
|
|
|
|
class FromUppercaseModelSdpaAttention(FromUppercaseModelAttention):
    """
    Attention computed with `torch.nn.functional.scaled_dot_product_attention` (SDPA).

    Shares all projection weights with `FromUppercaseModelAttention`; only the forward pass differs. When
    `output_attentions=True` is requested it warns once and delegates to the eager parent implementation,
    because SDPA cannot return attention weights.
    """

    # Adapted from FromUppercaseModelAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return `(attn_output, None)`, or the eager result when `output_attentions` is true."""
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "FromUppercaseModelModel is using FromUppercaseModelSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
                "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
                "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
                'be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
            )

        # Merge the padding mask and the causal mask into one additive mask (either may be absent).
        if causal_attention_mask is None:
            combined_mask = attention_mask
        elif attention_mask is None:
            combined_mask = causal_attention_mask
        else:
            combined_mask = attention_mask + causal_attention_mask

        batch, seq_len, hidden = hidden_states.size()

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, hidden) -> (batch, num_heads, seq, head_dim)
            return projected.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(hidden_states))
        k = split_heads(self.k_proj(hidden_states))
        v = split_heads(self.v_proj(hidden_states))

        # Work around a memory-efficient SDPA bug with non-contiguous inputs plus a custom mask on
        # CUDA before torch 2.2 (https://github.com/pytorch/pytorch/issues/112577).
        needs_contiguous = (
            not is_torch_greater_or_equal_than_2_2 and q.device.type == "cuda" and combined_mask is not None
        )
        if needs_contiguous:
            q, k, v = [t.contiguous() for t in (q, k, v)]

        drop_prob = self.dropout if self.training else 0.0
        context = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=combined_mask,
            dropout_p=drop_prob,
            scale=self.scale,
        )

        # (batch, num_heads, seq, head_dim) -> (batch, seq, hidden)
        context = context.transpose(1, 2).reshape(batch, seq_len, hidden)

        return self.out_proj(context), None
|
|
|
|
|
|
class FromUppercaseModelMLP(nn.Module):
    """Position-wise feed-forward network.

    Applies `fc1` (hidden_size -> intermediate_size), the activation named by `config.hidden_act`,
    then `fc2` (intermediate_size -> hidden_size).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Submodules are registered in this order so parameter initialisation and state_dict keys stay stable.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand, activate and project `hidden_states` back to `hidden_size`."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
|
|
|
|
|
|
# Lookup from `config._attn_implementation` to the attention module class built by the encoder layer.
FROM_UPPERCASE_MODEL_ATTENTION_CLASSES = dict(
    eager=FromUppercaseModelAttention,
    sdpa=FromUppercaseModelSdpaAttention,
    flash_attention_2=FromUppercaseModelFlashAttention2,
)
|
|
|
|
|
|
class FromUppercaseModelEncoderLayer(nn.Module):
    """Pre-norm encoder layer: LayerNorm -> self-attention -> residual, then LayerNorm -> MLP -> residual.

    The attention class is looked up in `FROM_UPPERCASE_MODEL_ATTENTION_CLASSES` using
    `config._attn_implementation`.
    """

    def __init__(self, config: FromUppercaseModelConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = FromUppercaseModelMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Forwarded unchanged to the self-attention module; may be `None`.
            causal_attention_mask (`torch.FloatTensor`): causal mask forwarded unchanged to the self-attention
                module; may be `None`, in which case no causal masking is applied.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `tuple`: `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is true.
            `attn_weights` may be `None` for the SDPA and flash-attention implementations.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|