diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 7285c8ba569..cb7c0a925b1 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -14,17 +14,16 @@
 # limitations under the License.
 """PyTorch ALBERT model."""
 
-import math
 import os
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -34,17 +33,20 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_2_2,
     prune_linear_layer,
 )
-from ...utils import ModelOutput, auto_docstring, logging
+from ...utils import ModelOutput, auto_docstring, is_torch_flex_attn_available, logging
 from .configuration_albert import AlbertConfig
 
 
+if is_torch_flex_attn_available():
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -199,14 +201,12 @@ class AlbertEmbeddings(nn.Module):
             "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
         )
 
-    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        past_key_values_length: int = 0,
     ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
@@ -216,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
         seq_length = input_shape[1]
 
         if position_ids is None:
-            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+            position_ids = self.position_ids[:, :seq_length]
 
         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -242,6 +242,64 @@ class AlbertEmbeddings(nn.Module):
         return embeddings
 
 
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: Optional[float] = None,
+    dropout: float = 0.0,
+    head_mask: Optional[torch.Tensor] = None,
+    use_cache: Optional[bool] = None,
+    **kwargs,
+):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
+    # Take the dot product between "query" and "key" to get the raw attention scores.
+    attn_weights = torch.matmul(query, key.transpose(2, 3))
+
+    # Relative positional embeddings
+    if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
+        query_length, key_length = query.shape[2], key.shape[2]
+        if use_cache:
+            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
+        else:
+            position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
+        position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
+        distance = position_ids_l - position_ids_r
+
+        positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
+        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility
+
+        if module.position_embedding_type == "relative_key":
+            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+            attn_weights = attn_weights + relative_position_scores
+        elif module.position_embedding_type == "relative_key_query":
+            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
+            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
+            attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key
+
+    # Scaling is shifted in case of embeddings being relative
+    attn_weights = attn_weights * scaling
+
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+    if head_mask is not None:
+        attn_weights = attn_weights * head_mask
+
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class AlbertAttention(nn.Module):
     def __init__(self, config: AlbertConfig):
         super().__init__()
@@ -250,19 +308,22 @@ class AlbertAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads}"
             )
+        self.config = config
 
         self.num_attention_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.attention_head_size = config.hidden_size // config.num_attention_heads
         self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.scaling = self.attention_head_size**-0.5
+
+        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
 
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pruned_heads = set()
 
@@ -271,11 +332,7 @@ class AlbertAttention(nn.Module):
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
 
-    # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
+        self.is_causal = False
 
     def prune_heads(self, heads: list[int]) -> None:
         if len(heads) == 0:
@@ -300,118 +357,49 @@ class AlbertAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
-        mixed_query_layer = self.query(hidden_states)
-        mixed_key_layer = self.key(hidden_states)
-        mixed_value_layer = self.value(hidden_states)
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # determine input shapes
+        bsz, tgt_len = hidden_states.shape[:-1]
+        src_len = tgt_len
 
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-        key_layer = self.transpose_for_scores(mixed_key_layer)
-        value_layer = self.transpose_for_scores(mixed_value_layer)
+        q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
+        kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
 
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        # get all proj
+        query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
+        key_layer = self.key(hidden_states).view(*kv_input_shape).transpose(1, 2)
+        value_layer = self.value(hidden_states).view(*kv_input_shape).transpose(1, 2)
 
-        if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.position_embedding_type != "absolute":
+                raise ValueError(
+                    f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
+                    'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
+                )
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            seq_length = hidden_states.size()[1]
-            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
-            distance = position_ids_l - position_ids_r
-            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
-
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
-                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
-
-        # Normalize the attention scores to probabilities.
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.attention_dropout(attention_probs)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-        context_layer = context_layer.transpose(2, 1).flatten(2)
-
-        projected_context_layer = self.dense(context_layer)
-        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
-        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
-        return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
-
-
-class AlbertSdpaAttention(AlbertAttention):
-    def __init__(self, config):
-        super().__init__(config)
-        self.dropout_prob = config.attention_probs_dropout_prob
-        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
-        if self.position_embedding_type != "absolute" or output_attentions:
-            logger.warning(
-                "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
-                "non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
-                "the eager attention implementation, but specifying the eager implementation will be required from "
-                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
-                '`attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
-
-        batch_size, seq_len, _ = hidden_states.size()
-        query_layer = self.transpose_for_scores(self.query(hidden_states))
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
-        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
-        # Reference: https://github.com/pytorch/pytorch/issues/112577
-        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
-            query_layer = query_layer.contiguous()
-            key_layer = key_layer.contiguous()
-            value_layer = value_layer.contiguous()
-
-        attention_output = torch.nn.functional.scaled_dot_product_attention(
-            query=query_layer,
-            key=key_layer,
-            value=value_layer,
-            attn_mask=attention_mask,
-            dropout_p=self.dropout_prob if self.training else 0.0,
-            is_causal=False,
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout.p,
+            scaling=self.scaling,
+            head_mask=head_mask,
+            # only for relevant for non-absolute positional embeddings
+            use_cache=False,
+            **kwargs,
        )
+        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
 
-        attention_output = attention_output.transpose(1, 2)
-        attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
+        attn_output = self.dense(attn_output)
+        attn_output = self.output_dropout(attn_output)
+        attn_output = self.LayerNorm(hidden_states + attn_output)
 
-        projected_context_layer = self.dense(attention_output)
-        projected_context_layer_dropout = self.output_dropout(projected_context_layer)
-        layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
-        return (layernormed_context_layer,)
-
-
-ALBERT_ATTENTION_CLASSES = {
-    "eager": AlbertAttention,
-    "sdpa": AlbertSdpaAttention,
-}
+        return attn_output, attn_weights
 
 
 class AlbertLayer(nn.Module):
@@ -422,7 +410,7 @@ class AlbertLayer(nn.Module):
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
         self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.attention = AlbertAttention(config)
         self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
@@ -433,10 +421,8 @@ class AlbertLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
+        attention_output = self.attention(hidden_states, attention_mask, head_mask)
 
         ffn_output = apply_chunking_to_forward(
             self.ff_chunk,
@@ -473,7 +459,7 @@ class AlbertLayerGroup(nn.Module):
         layer_attentions = ()
 
         for layer_index, albert_layer in enumerate(self.albert_layers):
-            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
+            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
             hidden_states = layer_output[0]
 
             if output_attentions:
@@ -548,7 +534,9 @@ class AlbertPreTrainedModel(PreTrainedModel):
     config_class = AlbertConfig
     load_tf_weights = load_tf_weights_in_albert
     base_model_prefix = "albert"
+    _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_flex_attn = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -691,27 +679,16 @@ class AlbertModel(AlbertPreTrainedModel):
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
         )
 
-        use_sdpa_attention_mask = (
-            self.attn_implementation == "sdpa"
-            and self.position_embedding_type == "absolute"
-            and head_mask is None
-            and not output_attentions
+        attention_mask = self._update_full_mask(
+            attention_mask,
+            embedding_output
        )
 
-        if use_sdpa_attention_mask:
-            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                attention_mask, embedding_output.dtype, tgt_len=seq_length
-            )
-        else:
-            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
-
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         encoder_outputs = self.encoder(
             embedding_output,
-            extended_attention_mask,
+            attention_mask,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -732,6 +709,29 @@ class AlbertModel(AlbertPreTrainedModel):
             attentions=encoder_outputs.attentions,
         )
 
+    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+    def _update_full_mask(
+        self,
+        attention_mask: Union[torch.Tensor, None],
+        inputs_embeds: torch.Tensor,
+    ):
+        if attention_mask is not None:
+            if self.config._attn_implementation == "flash_attention_2":
+                attention_mask = attention_mask if 0 in attention_mask else None
+            elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+            elif self.config._attn_implementation == "flex_attention":
+                if isinstance(attention_mask, torch.Tensor):
+                    attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        return attention_mask
+
 
 @auto_docstring(
     custom_intro="""
diff --git a/tests/models/albert/test_modeling_albert.py b/tests/models/albert/test_modeling_albert.py
index f0440fb349d..1fcc3d417d5 100644
--- a/tests/models/albert/test_modeling_albert.py
+++ b/tests/models/albert/test_modeling_albert.py
@@ -53,12 +53,12 @@ class AlbertModelTester:
         use_labels=True,
         vocab_size=32,
         embedding_size=8,
-        hidden_size=12,
+        hidden_size=16,
         num_hidden_layers=2,
         # this needs to be the same as `num_hidden_layers`!
         num_hidden_groups=2,
         num_attention_heads=4,
-        intermediate_size=16,
+        intermediate_size=20,
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
@@ -259,7 +259,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         if is_torch_available()
         else {}
     )
-    fx_compatible = True
+    fx_compatible = False  # will not be maintained
 
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -310,6 +310,7 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         for type in ["absolute", "relative_key", "relative_key_query"]:
             config_and_inputs[0].position_embedding_type = type
+            config_and_inputs[0]._attn_implementation = "eager"
             self.model_tester.create_and_check_model(*config_and_inputs)
 
     @slow
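Not part of the patch itself, just a usage sketch for reviewers: after this change ALBERT resolves its attention backend through `ALL_ATTENTION_FUNCTIONS`, so the backend is selected with the standard `attn_implementation` argument at load time. The checkpoint id below is illustrative; `"flash_attention_2"` additionally requires a CUDA device, fp16/bf16 weights and the `flash-attn` package, and `"flex_attention"` requires a recent PyTorch with FlexAttention support.

```python
import torch

from transformers import AlbertModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")

# Any of "eager", "sdpa", "flash_attention_2" or "flex_attention" can be passed here,
# subject to the hardware/package requirements mentioned above.
model = AlbertModel.from_pretrained("albert/albert-base-v2", attn_implementation="sdpa")

inputs = tokenizer("ALBERT ties parameters across its transformer layers.", return_tensors="pt")
with torch.no_grad():
    last_hidden = model(**inputs).last_hidden_state
print(last_hidden.shape)  # (batch_size, sequence_length, hidden_size)
```

Note that the new `AlbertAttention.forward` raises a `ValueError` when a non-`"absolute"` `position_embedding_type` is combined with a non-eager backend, so configurations using `relative_key`/`relative_key_query` embeddings still need `attn_implementation="eager"` (the test change above pins exactly that).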