Vasqu 2025-07-01 17:55:05 +02:00
parent 11de15bda4
commit dd7aeca424
2 changed files with 146 additions and 145 deletions

src/transformers/models/albert/modeling_albert.py

@@ -14,17 +14,16 @@
 # limitations under the License.
 """PyTorch ALBERT model."""
-import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -34,17 +33,20 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_2_2,
     prune_linear_layer,
 )
-from ...utils import ModelOutput, auto_docstring, logging
+from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
 from .configuration_albert import AlbertConfig
 
+
+if is_torch_flex_attn_available():
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
 
 logger = logging.get_logger(__name__)
@@ -199,14 +201,12 @@ class AlbertEmbeddings(nn.Module):
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
 
-    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        past_key_values_length: int = 0,
     ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -216,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
         seq_length = input_shape[1]
 
         if position_ids is None:
-            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+            position_ids = self.position_ids[:, :seq_length]
 
         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -242,6 +242,64 @@ class AlbertEmbeddings(nn.Module):
 
         return embeddings
 
 
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3))
+
+    # Relative positional embeddings
+    if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
+        query_length, key_length = query.shape[2], key.shape[2]
+        if use_cache:
+            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
+        else:
+            position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
+        position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
+        distance = position_ids_l - position_ids_r
+
+        positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
+        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility
+
+        if module.position_embedding_type == "relative_key":
+            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+            attn_weights = attn_weights + relative_position_scores
+        elif module.position_embedding_type == "relative_key_query":
+            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
+            attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
+
+    # Scaling is shifted in case of embeddings being relative
+    attn_weights = attn_weights * scaling
+
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class AlbertAttention(nn.Module):
     def __init__(self, config: AlbertConfig):
         super().__init__()
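Note: for orientation, a minimal standalone call of the new eager_attention_forward (not part of the diff; the dummy module and shapes below are illustrative and assume the patched modeling_albert module is importable):

import torch
from torch import nn

from transformers.models.albert.modeling_albert import eager_attention_forward


class DummyAttentionModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.position_embedding_type = "absolute"  # skips the relative-embedding branch


module = DummyAttentionModule().eval()
bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

attn_output, attn_weights = eager_attention_forward(module, q, k, v, attention_mask=None)
print(attn_output.shape)   # torch.Size([2, 8, 4, 16]) -- (bsz, seq_len, num_heads, head_dim), flattened later by the caller
print(attn_weights.shape)  # torch.Size([2, 4, 8, 8])  -- softmax attention probabilities per head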
@@ -250,19 +308,22 @@ class AlbertAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads}"
             )
 
+        self.config = config
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.attention_head_size = config.hidden_size // config.num_attention_heads
         self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.scaling = self.attention_head_size**-0.5
+        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
 
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
-        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pruned_heads = set()
@@ -271,11 +332,7 @@ class AlbertAttention(nn.Module):
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
 
-    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        self.is_causal = False
 
     def prune_heads(self, heads: list[int]) -> None:
         if len(heads) == 0:
@@ -300,118 +357,49 @@ class AlbertAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-        key_layer = self.transpose_for_scores(mixed_key_layer)
-        value_layer = self.transpose_for_scores(mixed_value_layer)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-
-        if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
-            distance = position_ids_l - position_ids_r
-            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
-
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
-
-        # Normalize the attention scores to probabilities.
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.attention_dropout(attention_probs)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-        context_layer = context_layer.transpose(2, 1).flatten(2)
-
-        projected_context_layer = self.dense(context_layer)
-        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
-        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
-        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
-
-
-class AlbertSdpaAttention(AlbertAttention):
-    def __init__(self, config):
-        super().__init__(config)
-        self.dropout_prob = config.attention_probs_dropout_prob
-        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
-        if self.position_embedding_type != "absolute" or output_attentions:
-            logger.warning(
-                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
-                "non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
-                "the eager attention implementation, but specifying the eager implementation will be required from "
-                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
-                '`attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
-
-        batch_size, seq_len, _ = hidden_states.size()
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
-        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
-        # Reference: https://github.com/pytorch/pytorch/issues/112577
-        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
-            query_layer = query_layer.contiguous()
-            key_layer = key_layer.contiguous()
-            value_layer = value_layer.contiguous()
-
-        attention_output = torch.nn.functional.scaled_dot_product_attention(
-            query=query_layer,
-            key=key_layer,
-            value=value_layer,
-            attn_mask=attention_mask,
-            dropout_p=self.dropout_prob if self.training else 0.0,
-            is_causal=False,
-        )
-
-        attention_output = attention_output.transpose(1, 2)
-        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
-
-        projected_context_layer = self.dense(attention_output)
-        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
-        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
-        return (layernormed_context_layer,)
-
-
-ALBERT_ATTENTION_CLASSES = {
-    "eager": AlbertAttention,
-    "sdpa": AlbertSdpaAttention,
-}
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # determine input shapes
+        bsz, tgt_len = hidden_states.shape[:-1]
+        src_len = tgt_len
+
+        q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
+        kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
+
+        # get all proj
+        query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
+        key_layer = self.key(hidden_states).view(*kv_input_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(*kv_input_shape).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.position_embedding_type != "absolute":
+                raise ValueError(
+                    f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
+                    'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
+                )
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout.p,
+            scaling=self.scaling,
+            head_mask=head_mask,
+            # only for relevant for non-absolute positional embeddings
+            use_cache=False,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
+
+        attn_output = self.dense(attn_output)
+        attn_output = self.output_dropout(attn_output)
+        attn_output = self.LayerNorm(hidden_states + attn_output)
+        return attn_output, attn_weights
 
 
 class AlbertLayer(nn.Module):
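Note: the rewritten forward assumes every entry in ALL_ATTENTION_FUNCTIONS follows the same calling convention as eager_attention_forward. A hedged, self-contained sketch of such a backend (the name and the SDPA-based body are illustrative only, not code from this commit):

from typing import Optional

import torch
from torch import nn


def sdpa_like_attention(
    module: nn.Module,
    query: torch.Tensor,  # (bsz, num_heads, tgt_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    **kwargs,
):
    # Backends receive head-major q/k/v and return the output re-ordered to
    # (bsz, tgt_len, num_heads, head_dim), which the caller then flattens to (bsz, tgt_len, hidden).
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None  # no attention weights available from SDPA

Non-absolute positional embeddings only work with the eager path, which is why the new forward raises for any other implementation.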
@@ -422,7 +410,7 @@ class AlbertLayer(nn.Module):
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
         self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = AlbertAttention(config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
@@ -433,10 +421,8 @@ class AlbertLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
+        attention_output = self.attention(hidden_states, attention_mask, head_mask)
 
         ffn_output = apply_chunking_to_forward(
             self.ff_chunk,
@@ -473,7 +459,7 @@ class AlbertLayerGroup(nn.Module):
         layer_attentions = ()
 
         for layer_index, albert_layer in enumerate(self.albert_layers):
-            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
+            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
             hidden_states = layer_output[0]
 
             if output_attentions:
@@ -548,7 +534,9 @@ class AlbertPreTrainedModel(PreTrainedModel):
     config_class = AlbertConfig
     load_tf_weights = load_tf_weights_in_albert
     base_model_prefix = "albert"
+    _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_flex_attn = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -691,27 +679,16 @@ class AlbertModel(AlbertPreTrainedModel):
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
         )
 
-        use_sdpa_attention_mask = (
-            self.attn_implementation == "sdpa"
-            and self.position_embedding_type == "absolute"
-            and head_mask is None
-            and not output_attentions
-        )
-
-        if use_sdpa_attention_mask:
-            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                attention_mask, embedding_output.dtype, tgt_len=seq_length
-            )
-        else:
-            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
+        attention_mask = self._update_full_mask(
+            attention_mask,
+            embedding_output,
+        )
 
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         encoder_outputs = self.encoder(
             embedding_output,
-            extended_attention_mask,
+            attention_mask,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -732,6 +709,29 @@ class AlbertModel(AlbertPreTrainedModel):
             attentions=encoder_outputs.attentions,
         )
 
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+    def _update_full_mask(
+        self,
+        attention_mask: Union[torch.Tensor, None],
+        inputs_embeds: torch.Tensor,
+    ):
+        if attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(attention_mask, torch.Tensor):
+                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        return attention_mask
+
 
 @auto_docstring(
     custom_intro="""

tests/models/albert/test_modeling_albert.py

@@ -53,12 +53,12 @@ class AlbertModelTester:
         use_labels=True,
         vocab_size=32,
         embedding_size=8,
-        hidden_size=12,
+        hidden_size=16,
         num_hidden_layers=2,
         # this needs to be the same as `num_hidden_layers`!
         num_hidden_groups=2,
         num_attention_heads=4,
-        intermediate_size=16,
+        intermediate_size=20,
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
@@ -259,7 +259,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         if is_torch_available()
         else {}
     )
-    fx_compatible = True
+    fx_compatible = False  # will not be maintained
 
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -310,6 +310,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         for type in ["absolute", "relative_key", "relative_key_query"]:
             config_and_inputs[0].position_embedding_type = type
+            config_and_inputs[0]._attn_implementation = "eager"
             self.model_tester.create_and_check_model(*config_and_inputs)
 
     @slow