commit dd7aeca424
parent 11de15bda4
Author: Vasqu
Date:   2025-07-01 17:55:05 +02:00
2 changed files with 146 additions and 145 deletions

View File: src/transformers/models/albert/modeling_albert.py

@@ -14,17 +14,16 @@
# limitations under the License.
"""PyTorch ALBERT model."""
import math
import os
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
@@ -34,17 +33,20 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
is_torch_greater_or_equal_than_2_2,
prune_linear_layer,
)
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
from .configuration_albert import AlbertConfig
if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -199,14 +201,12 @@ class AlbertEmbeddings(nn.Module):
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
)
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
@@ -216,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, : seq_length]
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -242,6 +242,64 @@ class AlbertEmbeddings(nn.Module):
return embeddings
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
**kwargs,
):
if scaling is None:
scaling = query.size(-1) ** -0.5
# Take the dot product between "query" and "key" to get the raw attention scores.
attn_weights = torch.matmul(query, key.transpose(2, 3))
# Relative positional embeddings
if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
query_length, key_length = query.shape[2], key.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
if module.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
attn_weights = attn_weights + relative_position_scores
elif module.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
# Scaling is applied after the relative position scores so that they are scaled as well (matching the original eager behavior)
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
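# Illustrative sketch (not part of this diff): the shapes this helper expects and returns,
# using a hypothetical stand-in module that only supplies the attributes read above
# (position_embedding_type, training):
#
#     import types, torch
#     dummy = types.SimpleNamespace(position_embedding_type="absolute", training=False)
#     q = k = v = torch.randn(2, 4, 8, 16)               # [batch, heads, seq_len, head_dim]
#     out, weights = eager_attention_forward(dummy, q, k, v, attention_mask=None)
#     # out: [2, 8, 4, 16] (transposed back to [batch, seq_len, heads, head_dim]), weights: [2, 4, 8, 8]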
class AlbertAttention(nn.Module):
def __init__(self, config: AlbertConfig):
super().__init__()
@@ -250,19 +308,22 @@ class AlbertAttention(nn.Module):
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads}"
)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pruned_heads = set()
@@ -271,11 +332,7 @@ class AlbertAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.is_causal = False
def prune_heads(self, heads: list[int]) -> None:
if len(heads) == 0:
@@ -300,118 +357,49 @@ class AlbertAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
src_len = tgt_len
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# get all proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(*kv_input_shape).transpose(1, 2)
if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in AlbertModel's forward() function)
attention_scores = attention_scores + attention_mask
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.position_embedding_type != "absolute":
raise ValueError(
f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
'positional embeddings do not work with it. Please load the model with `attn_implementation="eager"`.'
)
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.attention_dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(2, 1).flatten(2)
projected_context_layer = self.dense(context_layer)
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
class AlbertSdpaAttention(AlbertAttention):
def __init__(self, config):
super().__init__(config)
self.dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
if self.position_embedding_type != "absolute" or output_attentions:
logger.warning(
"AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
"the eager attention implementation, but specifying the eager implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
batch_size, seq_len, _ = hidden_states.size()
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
attention_output = torch.nn.functional.scaled_dot_product_attention(
query=query_layer,
key=key_layer,
value=value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=False,
attn_output, attn_weights = attention_interface(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout.p,
scaling=self.scaling,
head_mask=head_mask,
# only relevant for non-absolute positional embeddings
use_cache=False,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attention_output = attention_output.transpose(1, 2)
attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
attn_output = self.dense(attn_output)
attn_output = self.output_dropout(attn_output)
attn_output = self.LayerNorm(hidden_states + attn_output)
projected_context_layer = self.dense(attention_output)
projected_context_layer_dropout = self.output_dropout(projected_context_layer)
layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
return (layernormed_context_layer,)
ALBERT_ATTENTION_CLASSES = {
"eager": AlbertAttention,
"sdpa": AlbertSdpaAttention,
}
return attn_output, attn_weights
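# Illustrative sketch (not part of this diff): whichever backend is selected, the refactored
# forward returns a 2-tuple of (attn_output, attn_weights), where attn_weights may be None for
# non-eager backends. Config values below are hypothetical; any hidden_size divisible by
# num_attention_heads works, and with no implementation set the eager path is used:
#
#     cfg = AlbertConfig(hidden_size=16, num_attention_heads=4)
#     attn = AlbertAttention(cfg).eval()
#     out, weights = attn(torch.randn(2, 8, 16))
#     # out: [2, 8, 16], weights: [2, 4, 8, 8]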
class AlbertLayer(nn.Module):
@@ -422,7 +410,7 @@ class AlbertLayer(nn.Module):
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
self.attention = AlbertAttention(config)
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
self.activation = ACT2FN[config.hidden_act]
@@ -433,10 +421,8 @@ class AlbertLayer(nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.attention(hidden_states, attention_mask, head_mask)
ffn_output = apply_chunking_to_forward(
self.ff_chunk,
@@ -473,7 +459,7 @@ class AlbertLayerGroup(nn.Module):
layer_attentions = ()
for layer_index, albert_layer in enumerate(self.albert_layers):
layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
hidden_states = layer_output[0]
if output_attentions:
@@ -548,7 +534,9 @@ class AlbertPreTrainedModel(PreTrainedModel):
config_class = AlbertConfig
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
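# Illustrative usage (not part of this diff): with the flags above, a checkpoint can be loaded
# against any supported backend via `attn_implementation`, assuming the corresponding kernels
# are installed:
#
#     model = AlbertModel.from_pretrained("albert/albert-base-v2", attn_implementation="sdpa")
#     model = AlbertModel.from_pretrained("albert/albert-base-v2", attn_implementation="flex_attention")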
def _init_weights(self, module):
"""Initialize the weights."""
@@ -691,27 +679,16 @@ class AlbertModel(AlbertPreTrainedModel):
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
use_sdpa_attention_mask = (
self.attn_implementation == "sdpa"
and self.position_embedding_type == "absolute"
and head_mask is None
and not output_attentions
attention_mask = self._update_full_mask(
attention_mask,
embedding_output
)
if use_sdpa_attention_mask:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
extended_attention_mask,
attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -732,6 +709,29 @@ class AlbertModel(AlbertPreTrainedModel):
attentions=encoder_outputs.attentions,
)
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
return attention_mask
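# Illustrative sketch (not part of this diff) of what the sdpa/eager branches build: a 2D
# padding mask is expanded to an additive 4D float mask, roughly equivalent to
#
#     mask_2d = torch.tensor([[1, 1, 1, 0]])                          # [bsz, src_len], 0 = padding
#     mask_4d = mask_2d[:, None, None, :].to(torch.float32)           # [bsz, 1, 1, src_len]
#     mask_4d = (1.0 - mask_4d) * torch.finfo(torch.float32).min      # keep -> 0.0, pad -> min float
#
# (the actual helper also expands over tgt_len); flash_attention_2 keeps the 2D mask or drops it
# when there is no padding, and flex_attention wraps it into a block mask object.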
@auto_docstring(
custom_intro="""

View File: tests/models/albert/test_modeling_albert.py

@@ -53,12 +53,12 @@ class AlbertModelTester:
use_labels=True,
vocab_size=32,
embedding_size=8,
hidden_size=12,
hidden_size=16,
num_hidden_layers=2,
# this needs to be the same as `num_hidden_layers`!
num_hidden_groups=2,
num_attention_heads=4,
intermediate_size=16,
intermediate_size=20,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
@@ -259,7 +259,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
if is_torch_available()
else {}
)
fx_compatible = True
fx_compatible = False # will not be maintained
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -310,6 +310,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
config_and_inputs[0]._attn_implementation = "eager"
self.model_tester.create_and_check_model(*config_and_inputs)
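# Illustrative (not part of this diff): non-absolute position embeddings are only supported
# on the eager path, which is why the implementation is pinned above; a hypothetical
# sdpa + relative_key combination is expected to raise:
#
#     cfg = AlbertConfig(position_embedding_type="relative_key", attn_implementation="sdpa")
#     AlbertModel(cfg)(torch.randint(0, cfg.vocab_size, (1, 8)))   # ValueError from AlbertAttention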
@slow