Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[ESM] Add flash-attention-2 backend for ESM-2 (#38023)
* Add flash-attention-2 backend for ESM-2
* update extended_attention_mask for fa2
* add test_flash_attn_2_equivalence test

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Parent: 7b5e327c6e
Commit: d69945e5fc
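This commit routes ESM-2 self-attention through the flash-attention-2 kernels whenever the model is loaded with `attn_implementation="flash_attention_2"`. A minimal usage sketch follows; the checkpoint name `facebook/esm2_t33_650M_UR50D`, the example protein sequence, and the fp16/CUDA setup are illustrative assumptions rather than part of the commit itself.

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Assumed public ESM-2 checkpoint; any ESM-2 checkpoint should behave the same way.
    model_name = "facebook/esm2_t33_650M_UR50D"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Flash attention kernels require half precision and a CUDA device.
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    ).to("cuda")

    # Arbitrary example protein sequence.
    inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt").to("cuda")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    print(hidden.shape)

Eager attention remains the default; passing `attn_implementation="eager"` (or omitting the argument) keeps the previous behaviour.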
@@ -1,5 +1,6 @@
 # coding=utf-8
 # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,10 +31,14 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, logging
+from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
 from .configuration_esm import EsmConfig
 
 
+if is_flash_attn_2_available():
+    from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -111,8 +116,8 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
 
         return (
-            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
-            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
+            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
         )
 
 
@@ -245,6 +250,8 @@ class EsmEmbeddings(nn.Module):
 class EsmSelfAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
+        self.config = config
+
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
@@ -393,10 +400,128 @@ class EsmSelfOutput(nn.Module):
         return hidden_states
 
 
+class EsmFlashAttention2(EsmSelfAttention):
+    """
+    ESM flash attention module. This module inherits from `EsmSelfAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__(config, position_embedding_type=position_embedding_type)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self.dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        # Flash attention doesn't support output_attentions or cross attention
+        if output_attentions or head_mask is not None or encoder_hidden_states is not None:
+            logger.warning_once(
+                "EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. "
+                "Falling back to the manual attention implementation. This warning can be removed using "
+                'the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        if past_key_value is not None:
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32.
+        input_dtype = query_layer.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.query.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_layer = query_layer.to(target_dtype)
+            key_layer = key_layer.to(target_dtype)
+            value_layer = value_layer.to(target_dtype)
+
+        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
+        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
+        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
+        # ESM code and fix rotary embeddings.
+        query_layer = query_layer * self.attention_head_size**-0.5
+
+        if self.position_embedding_type == "rotary":
+            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+        elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings")
+
+        # It would likely be faster to change self.transpose_for_scores to output the correct
+        # dimensions for flash_attention_2, but that would also mean changing the rotary embedding
+        # functions. Here we just permute the dimensions to match the expected input.
+        attn_output = _flash_attention_forward(
+            query_layer.permute(0, 2, 1, 3),
+            key_layer.permute(0, 2, 1, 3),
+            value_layer.permute(0, 2, 1, 3),
+            attention_mask,
+            query_length=q_len,
+            is_causal=self.is_decoder,
+            softmax_scale=1.0,
+            dropout=self.dropout_prob if self.training else 0.0,
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+
+        outputs = (attn_output, None)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+
+        return outputs
+
+
+ESM_ATTENTION_CLASSES = {
+    "eager": EsmSelfAttention,
+    "flash_attention_2": EsmFlashAttention2,
+}
+
+
 class EsmAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.self = EsmSelfAttention(config)
+        self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config)
         self.output = EsmSelfOutput(config)
         self.pruned_heads = set()
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -673,6 +798,7 @@ class EsmPreTrainedModel(PreTrainedModel):
     base_model_prefix = "esm"
     supports_gradient_checkpointing = True
     _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
+    _supports_flash_attn_2 = True
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
    def _init_weights(self, module):
@@ -806,9 +932,13 @@ class EsmModel(EsmPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
 
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        if self.config._attn_implementation == "flash_attention_2":
+            extended_attention_mask = attention_mask
+
+        else:
+            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+            # ourselves in which case we just need to make it broadcastable to all heads.
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1980,6 +1980,7 @@ class EsmFoldingTrunk(nn.Module):
 )
 class EsmForProteinFolding(EsmPreTrainedModel):
     _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
+    _supports_flash_attn_2 = False
 
     def __init__(self, config):
         super().__init__(config)
@@ -2050,6 +2051,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
         position_ids: Optional[torch.Tensor] = None,
         masking_pattern: Optional[torch.Tensor] = None,
         num_recycles: Optional[int] = None,
+        output_hidden_states: Optional[bool] = False,
     ) -> EsmForProteinFoldingOutput:
         r"""
         masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -13,10 +13,22 @@
 # limitations under the License.
 """Testing suite for the PyTorch ESM model."""
 
+import tempfile
 import unittest
 
+import pytest
+
 from transformers import EsmConfig, is_torch_available
-from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    TestCasePlus,
+    is_flaky,
+    require_bitsandbytes,
+    require_flash_attn,
+    require_torch,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -59,6 +71,7 @@ class EsmModelTester:
         num_labels=3,
         num_choices=4,
         scope=None,
+        position_embedding_type="rotary",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -82,6 +95,7 @@ class EsmModelTester:
         self.num_labels = num_labels
         self.num_choices = num_choices
         self.scope = scope
+        self.position_embedding_type = position_embedding_type
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -116,6 +130,7 @@ class EsmModelTester:
             max_position_embeddings=self.max_position_embeddings,
             type_vocab_size=self.type_vocab_size,
             initializer_range=self.initializer_range,
+            position_embedding_type=self.position_embedding_type,
         )
 
     def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
@@ -296,6 +311,39 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_resize_tokens_embeddings(self):
         pass
 
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    @is_flaky()
+    @slow
+    def test_flash_attn_2_equivalence(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                self.skipTest(reason="Model does not support Flash Attention 2")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+                model.to(torch_device)
+
+                dummy_input = inputs_dict[model_class.main_input_name]
+                dummy_input = dummy_input.to(torch_device)
+                outputs = model(dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                torch.testing.assert_close(logits_fa, logits, atol=1e-2, rtol=1e-3)
+
 
 @slow
 @require_torch