[ESM] Add flash-attention-2 backend for ESM-2 (#38023)

* Add flash-attention-2 backend for ESM-2

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* update extended_attention_mask for fa2

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* add test_flash_attn_2_equivalence test

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Peter St. John 2025-05-16 07:11:56 -06:00 committed by GitHub
parent 7b5e327c6e
commit d69945e5fc
3 changed files with 188 additions and 8 deletions
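
For context, a minimal usage sketch of what this backend enables is shown below. The checkpoint name, dtype, device, and input sequence are illustrative examples rather than part of the change, and flash-attn must be installed with a supported GPU.

# Illustrative sketch: opt an ESM-2 checkpoint into the new flash-attention-2 backend.
# The checkpoint name is an example; flash attention requires fp16/bf16 weights and a CUDA device.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = AutoModel.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt").to("cuda")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state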

View File

@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,10 +31,14 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_esm import EsmConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
@@ -111,8 +116,8 @@ class RotaryEmbedding(torch.nn.Module):
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
)
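
A sketch of the dtype promotion that the added `.to(dtype=...)` calls guard against: if the cached cos/sin tables live in a higher precision than the fp16/bf16 queries and keys, PyTorch promotes the product to float32, which the flash-attention kernel would then reject. Shapes and dtypes below are illustrative assumptions, not taken from the diff.

import torch

q = torch.randn(1, 8, 16, 64, dtype=torch.float16)   # hypothetical [batch, heads, seq, head_dim]
cos = torch.randn(16, 64, dtype=torch.float32)        # cached table kept in float32
print((q * cos).dtype)                    # torch.float32 -- silent upcast via type promotion
print((q * cos).to(dtype=q.dtype).dtype)  # torch.float16 -- what the patched return preserves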
@@ -245,6 +250,8 @@ class EsmEmbeddings(nn.Module):
class EsmSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
@@ -393,10 +400,128 @@ class EsmSelfOutput(nn.Module):
return hidden_states
class EsmFlashAttention2(EsmSelfAttention):
"""
ESM flash attention module. This module inherits from `EsmSelfAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, which needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type=position_embedding_type)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
self.dropout_prob = config.attention_probs_dropout_prob
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
# Flash attention doesn't support output_attentions, head_mask, or cross attention
if output_attentions or head_mask is not None or encoder_hidden_states is not None:
logger.warning_once(
"EsmFlashAttention2 does not support output_attentions, head_mask, or cross_attention. "
"Falling back to the manual attention implementation. This warning can be removed using "
'the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
bsz, q_len, _ = hidden_states.size()
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
if past_key_value is not None:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
# In PEFT, we usually cast the layer norms to float32 for training stability,
# so the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32.
input_dtype = query_layer.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.query.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_layer = query_layer.to(target_dtype)
key_layer = key_layer.to(target_dtype)
value_layer = value_layer.to(target_dtype)
# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
# but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
# ESM code and fix rotary embeddings.
query_layer = query_layer * self.attention_head_size**-0.5
if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
elif self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
raise ValueError(f"ESM flash attention does not support {self.position_embedding_type} embeddings")
# It would likely be faster to change self.transpose_for_scores to output the correct
# dimensions for flash_attention_2, but that would also mean changing the rotary embedding
# functions. Here we just permute the dimensions to match the expected input.
attn_output = _flash_attention_forward(
query_layer.permute(0, 2, 1, 3),
key_layer.permute(0, 2, 1, 3),
value_layer.permute(0, 2, 1, 3),
attention_mask,
query_length=q_len,
is_causal=self.is_decoder,
softmax_scale=1.0,
dropout=self.dropout_prob if self.training else 0.0,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
outputs = (attn_output, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
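
Shape bookkeeping for the permute/reshape above, as a small sketch with illustrative dimensions: `transpose_for_scores` keeps the eager layout `[batch, num_heads, seq_len, head_dim]`, while the flash-attention helper consumes `[batch, seq_len, num_heads, head_dim]`, and the final reshape folds the heads back into the hidden dimension. `softmax_scale=1.0` is passed because the query was already pre-scaled by `attention_head_size**-0.5`.

import torch

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64               # illustrative sizes
query_layer = torch.randn(bsz, num_heads, seq_len, head_dim)   # eager layout from transpose_for_scores
flash_layout = query_layer.permute(0, 2, 1, 3)                 # [2, 16, 8, 64], layout expected by the kernel
attn_output = flash_layout                                     # stand-in for the kernel's output
print(attn_output.reshape(bsz, seq_len, -1).shape)             # torch.Size([2, 16, 512]) == [batch, seq, hidden]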
ESM_ATTENTION_CLASSES = {
"eager": EsmSelfAttention,
"flash_attention_2": EsmFlashAttention2,
}
class EsmAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = EsmSelfAttention(config)
self.self = ESM_ATTENTION_CLASSES[config._attn_implementation](config)
self.output = EsmSelfOutput(config)
self.pruned_heads = set()
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -673,6 +798,7 @@ class EsmPreTrainedModel(PreTrainedModel):
base_model_prefix = "esm"
supports_gradient_checkpointing = True
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]
_supports_flash_attn_2 = True
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
def _init_weights(self, module):
@@ -806,9 +932,13 @@ class EsmModel(EsmPreTrainedModel):
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
if self.config._attn_implementation == "flash_attention_2":
extended_attention_mask = attention_mask
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
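
A sketch contrasting the two mask formats handled above (values illustrative): the eager path expands a `[batch, seq_len]` padding mask into a broadcastable additive mask with large negative values at padded positions (roughly what `get_extended_attention_mask` produces), whereas the flash-attention path consumes the 2D 0/1 mask directly and uses it to unpad the batch.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                 # [batch, seq_len] padding mask
extended = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min
print(extended.shape)        # torch.Size([1, 1, 1, 4]) -- additive mask broadcast over heads/queries
print(attention_mask.shape)  # torch.Size([1, 4]) -- passed through unchanged for flash_attention_2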

View File

@@ -1980,6 +1980,7 @@ class EsmFoldingTrunk(nn.Module):
)
class EsmForProteinFolding(EsmPreTrainedModel):
_no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
_supports_flash_attn_2 = False
def __init__(self, config):
super().__init__(config)
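
Since `_supports_flash_attn_2` stays `False` here, requesting the new backend for the folding model is expected to fail fast rather than silently fall back. A hedged illustration, where the checkpoint name is only an example:

from transformers import EsmForProteinFolding

try:
    EsmForProteinFolding.from_pretrained(
        "facebook/esmfold_v1", attn_implementation="flash_attention_2"
    )
except ValueError as err:
    print(err)  # the base class rejects flash_attention_2 for models that do not declare support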
@@ -2050,6 +2051,7 @@ class EsmForProteinFolding(EsmPreTrainedModel):
position_ids: Optional[torch.Tensor] = None,
masking_pattern: Optional[torch.Tensor] = None,
num_recycles: Optional[int] = None,
output_hidden_states: Optional[bool] = False,
) -> EsmForProteinFoldingOutput:
r"""
masking_pattern (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

View File

@@ -13,10 +13,22 @@
# limitations under the License.
"""Testing suite for the PyTorch ESM model."""
import tempfile
import unittest
import pytest
from transformers import EsmConfig, is_torch_available
from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
from transformers.testing_utils import (
TestCasePlus,
is_flaky,
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -59,6 +71,7 @@ class EsmModelTester:
num_labels=3,
num_choices=4,
scope=None,
position_embedding_type="rotary",
):
self.parent = parent
self.batch_size = batch_size
@@ -82,6 +95,7 @@ class EsmModelTester:
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.position_embedding_type = position_embedding_type
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -116,6 +130,7 @@ class EsmModelTester:
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
position_embedding_type=self.position_embedding_type,
)
def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
@@ -296,6 +311,39 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_resize_tokens_embeddings(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(reason="Model does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)
model_fa.to(torch_device)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
model.to(torch_device)
dummy_input = inputs_dict[model_class.main_input_name]
dummy_input = dummy_input.to(torch_device)
outputs = model(dummy_input, output_hidden_states=True)
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
logits = outputs.hidden_states[-1]
logits_fa = outputs_fa.hidden_states[-1]
torch.testing.assert_close(logits_fa, logits, atol=1e-2, rtol=1e-3)
@slow
@require_torch
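
For a quick manual check outside the test harness, the same equivalence comparison can be reproduced against a public checkpoint. The sketch below mirrors the tolerances used in `test_flash_attn_2_equivalence`; the checkpoint name and sequences are illustrative.

import torch
from transformers import AutoModel, AutoTokenizer

name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(name)
batch = tokenizer(["MKTAYIAKQR", "GLIEVQ"], padding=True, return_tensors="pt").to("cuda")

eager = AutoModel.from_pretrained(name, torch_dtype=torch.float16, attn_implementation="eager").to("cuda")
fa2 = AutoModel.from_pretrained(name, torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")

with torch.no_grad():
    hidden_eager = eager(**batch, output_hidden_states=True).hidden_states[-1]
    hidden_fa2 = fa2(**batch, output_hidden_states=True).hidden_states[-1]

torch.testing.assert_close(hidden_fa2, hidden_eager, atol=1e-2, rtol=1e-3)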